diff --git a/common/bidi/spec/BUILD.bazel b/common/bidi/spec/BUILD.bazel new file mode 100644 index 0000000000000..74c3cffd35ed0 --- /dev/null +++ b/common/bidi/spec/BUILD.bazel @@ -0,0 +1,13 @@ +package( + default_visibility = [ + "//py:__pkg__", + ], +) + +exports_files( + srcs = [ + "all.cddl", + "local.cddl", + "remote.cddl", + ], +) diff --git a/common/bidi/spec/all.cddl b/common/bidi/spec/all.cddl new file mode 100644 index 0000000000000..85c4536a2cd10 --- /dev/null +++ b/common/bidi/spec/all.cddl @@ -0,0 +1,2489 @@ +Command = { + id: js-uint, + CommandData, + Extensible, +} + +CommandData = ( + BrowserCommand // + BrowsingContextCommand // + EmulationCommand // + InputCommand // + NetworkCommand // + ScriptCommand // + SessionCommand // + StorageCommand // + WebExtensionCommand +) + +EmptyParams = { + Extensible +} + +Message = ( + CommandResponse / + ErrorResponse / + Event +) + +CommandResponse = { + type: "success", + id: js-uint, + result: ResultData, + Extensible +} + +ErrorResponse = { + type: "error", + id: js-uint / null, + error: ErrorCode, + message: text, + ? stacktrace: text, + Extensible +} + +ResultData = ( + BrowserResult / + BrowsingContextResult / + EmulationResult / + InputResult / + NetworkResult / + ScriptResult / + SessionResult / + StorageResult / + WebExtensionResult +) + +EmptyResult = { + Extensible +} + +Event = { + type: "event", + EventData, + Extensible +} + +EventData = ( + BrowsingContextEvent // + InputEvent // + LogEvent // + NetworkEvent // + ScriptEvent +) + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +ErrorCode = "invalid argument" / + "invalid selector" / + "invalid session id" / + "invalid web extension" / + "move target out of bounds" / + "no such alert" / + "no such network collector" / + "no such element" / + "no such frame" / + "no such handle" / + "no such history entry" / + "no such intercept" / + "no such network data" / + "no such node" / + "no such request" / + "no such script" / + "no such storage partition" / + "no such user context" / + "no such web extension" / + "session not created" / + "unable to capture screen" / + "unable to close browser" / + "unable to set cookie" / + "unable to set file input" / + "unavailable network data" / + "underspecified storage partition" / + "unknown command" / + "unknown error" / + "unsupported operation" + +SessionCommand = ( + session.End // + session.New // + session.Status // + session.Subscribe // + session.Unsubscribe +) + +SessionResult = ( + session.EndResult / + session.NewResult / + session.StatusResult / + session.SubscribeResult / + session.UnsubscribeResult +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.SubscribeParameters = { + events: [+text], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +session.UnsubscribeByIDRequest = { + subscriptions: [+session.Subscription], +} + +session.UnsubscribeByAttributesRequest = { + events: [+text], +} + +session.Status = ( + method: "session.status", + params: EmptyParams, +) + +session.StatusResult = { + ready: bool, + message: text, +} + +session.New = ( + method: "session.new", + params: session.NewParameters +) + +session.NewParameters = { + capabilities: session.CapabilitiesRequest +} + +session.NewResult = { + sessionId: text, + capabilities: { + acceptInsecureCerts: bool, + browserName: text, + browserVersion: text, + platformName: text, + setWindowRect: bool, + userAgent: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + ? webSocketUrl: text, + Extensible + } +} + +session.End = ( + method: "session.end", + params: EmptyParams +) + + +session.EndResult = EmptyResult + +session.Subscribe = ( + method: "session.subscribe", + params: session.SubscribeParameters +) + +session.SubscribeResult = { + subscription: session.Subscription, +} + +session.Unsubscribe = ( + method: "session.unsubscribe", + params: session.UnsubscribeParameters, +) + +session.UnsubscribeParameters = session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + +session.UnsubscribeResult = EmptyResult + +BrowserCommand = ( + browser.Close // + browser.CreateUserContext // + browser.GetClientWindows // + browser.GetUserContexts // + browser.RemoveUserContext // + browser.SetClientWindowState // + browser.SetDownloadBehavior +) + +BrowserResult = ( + browser.CloseResult / + browser.CreateUserContextResult / + browser.GetClientWindowsResult / + browser.GetUserContextsResult / + browser.RemoveUserContextResult / + browser.SetClientWindowStateResult / + browser.SetDownloadBehaviorResult +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.Close = ( + method: "browser.close", + params: EmptyParams, +) + +browser.CloseResult = EmptyResult + +browser.CreateUserContext = ( + method: "browser.createUserContext", + params: browser.CreateUserContextParameters, +) + +browser.CreateUserContextParameters = { + ? acceptInsecureCerts: bool, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler +} + +browser.CreateUserContextResult = browser.UserContextInfo + +browser.GetClientWindows = ( + method: "browser.getClientWindows", + params: EmptyParams, +) + +browser.GetClientWindowsResult = { + clientWindows: [ * browser.ClientWindowInfo] +} + +browser.GetUserContexts = ( + method: "browser.getUserContexts", + params: EmptyParams, +) + +browser.GetUserContextsResult = { + userContexts: [ + browser.UserContextInfo] +} + +browser.RemoveUserContext = ( + method: "browser.removeUserContext", + params: browser.RemoveUserContextParameters +) + +browser.RemoveUserContextParameters = { + userContext: browser.UserContext +} + +browser.RemoveUserContextResult = EmptyResult + +browser.SetClientWindowState = ( + method: "browser.setClientWindowState", + params: browser.SetClientWindowStateParameters +) + +browser.SetClientWindowStateParameters = { + clientWindow: browser.ClientWindow, + (browser.ClientWindowNamedState // browser.ClientWindowRectState) +} + +browser.ClientWindowNamedState = ( + state: "fullscreen" / "maximized" / "minimized" +) + +browser.ClientWindowRectState = ( + state: "normal", + ? width: js-uint, + ? height: js-uint, + ? x: js-int, + ? y: js-int, +) + +browser.SetClientWindowStateResult = browser.ClientWindowInfo + +browser.SetDownloadBehavior = ( + method: "browser.setDownloadBehavior", + params: browser.SetDownloadBehaviorParameters +) + +browser.SetDownloadBehaviorParameters = { + downloadBehavior: browser.DownloadBehavior / null, + ? userContexts: [+browser.UserContext] +} + +browser.DownloadBehavior = { + ( + browser.DownloadBehaviorAllowed // + browser.DownloadBehaviorDenied + ) +} + +browser.DownloadBehaviorAllowed = ( + type: "allowed", + destinationFolder: text +) + +browser.DownloadBehaviorDenied = ( + type: "denied" +) + +browser.SetDownloadBehaviorResult = EmptyResult + +BrowsingContextCommand = ( + browsingContext.Activate // + browsingContext.CaptureScreenshot // + browsingContext.Close // + browsingContext.Create // + browsingContext.GetTree // + browsingContext.HandleUserPrompt // + browsingContext.LocateNodes // + browsingContext.Navigate // + browsingContext.Print // + browsingContext.Reload // + browsingContext.SetViewport // + browsingContext.TraverseHistory +) + +BrowsingContextResult = ( + browsingContext.ActivateResult / + browsingContext.CaptureScreenshotResult / + browsingContext.CloseResult / + browsingContext.CreateResult / + browsingContext.GetTreeResult / + browsingContext.HandleUserPromptResult / + browsingContext.LocateNodesResult / + browsingContext.NavigateResult / + browsingContext.PrintResult / + browsingContext.ReloadResult / + browsingContext.SetViewportResult / + browsingContext.TraverseHistoryResult +) + +BrowsingContextEvent = ( + browsingContext.ContextCreated // + browsingContext.ContextDestroyed // + browsingContext.DomContentLoaded // + browsingContext.DownloadEnd // + browsingContext.DownloadWillBegin // + browsingContext.FragmentNavigated // + browsingContext.HistoryUpdated // + browsingContext.Load // + browsingContext.NavigationAborted // + browsingContext.NavigationCommitted // + browsingContext.NavigationFailed // + browsingContext.NavigationStarted // + browsingContext.UserPromptClosed // + browsingContext.UserPromptOpened +) + +browsingContext.BrowsingContext = text; + +browsingContext.InfoList = [*browsingContext.Info] + +browsingContext.Info = { + children: browsingContext.InfoList / null, + clientWindow: browser.ClientWindow, + context: browsingContext.BrowsingContext, + originalOpener: browsingContext.BrowsingContext / null, + url: text, + userContext: browser.UserContext, + ? parent: browsingContext.BrowsingContext / null, +} + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.BaseNavigationInfo = ( + context: browsingContext.BrowsingContext, + navigation: browsingContext.Navigation / null, + timestamp: js-uint, + url: text, +) + +browsingContext.NavigationInfo = { + browsingContext.BaseNavigationInfo +} + +browsingContext.ReadinessState = "none" / "interactive" / "complete" + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.Activate = ( + method: "browsingContext.activate", + params: browsingContext.ActivateParameters +) + +browsingContext.ActivateParameters = { + context: browsingContext.BrowsingContext +} + +browsingContext.ActivateResult = EmptyResult + +browsingContext.CaptureScreenshot = ( + method: "browsingContext.captureScreenshot", + params: browsingContext.CaptureScreenshotParameters +) + +browsingContext.CaptureScreenshotParameters = { + context: browsingContext.BrowsingContext, + ? origin: ("viewport" / "document") .default "viewport", + ? format: browsingContext.ImageFormat, + ? clip: browsingContext.ClipRectangle, +} + +browsingContext.ImageFormat = { + type: text, + ? quality: 0.0..1.0, +} + +browsingContext.ClipRectangle = ( + browsingContext.BoxClipRectangle / + browsingContext.ElementClipRectangle +) + +browsingContext.ElementClipRectangle = { + type: "element", + element: script.SharedReference +} + +browsingContext.BoxClipRectangle = { + type: "box", + x: float, + y: float, + width: float, + height: float +} + +browsingContext.CaptureScreenshotResult = { + data: text +} + +browsingContext.Close = ( + method: "browsingContext.close", + params: browsingContext.CloseParameters +) + +browsingContext.CloseParameters = { + context: browsingContext.BrowsingContext, + ? promptUnload: bool .default false +} + +browsingContext.CloseResult = EmptyResult + +browsingContext.Create = ( + method: "browsingContext.create", + params: browsingContext.CreateParameters +) + +browsingContext.CreateType = "tab" / "window" + +browsingContext.CreateParameters = { + type: browsingContext.CreateType, + ? referenceContext: browsingContext.BrowsingContext, + ? background: bool .default false, + ? userContext: browser.UserContext +} + +browsingContext.CreateResult = { + context: browsingContext.BrowsingContext +} + +browsingContext.GetTree = ( + method: "browsingContext.getTree", + params: browsingContext.GetTreeParameters +) + +browsingContext.GetTreeParameters = { + ? maxDepth: js-uint, + ? root: browsingContext.BrowsingContext, +} + +browsingContext.GetTreeResult = { + contexts: browsingContext.InfoList +} + +browsingContext.HandleUserPrompt = ( + method: "browsingContext.handleUserPrompt", + params: browsingContext.HandleUserPromptParameters +) + +browsingContext.HandleUserPromptParameters = { + context: browsingContext.BrowsingContext, + ? accept: bool, + ? userText: text, +} + +browsingContext.HandleUserPromptResult = EmptyResult + +browsingContext.LocateNodes = ( + method: "browsingContext.locateNodes", + params: browsingContext.LocateNodesParameters +) + +browsingContext.LocateNodesParameters = { + context: browsingContext.BrowsingContext, + locator: browsingContext.Locator, + ? maxNodeCount: (js-uint .ge 1), + ? serializationOptions: script.SerializationOptions, + ? startNodes: [ + script.SharedReference ] +} + +browsingContext.LocateNodesResult = { + nodes: [ * script.NodeRemoteValue ] +} + +browsingContext.Navigate = ( + method: "browsingContext.navigate", + params: browsingContext.NavigateParameters +) + +browsingContext.NavigateParameters = { + context: browsingContext.BrowsingContext, + url: text, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.NavigateResult = { + navigation: browsingContext.Navigation / null, + url: text, +} + +browsingContext.Print = ( + method: "browsingContext.print", + params: browsingContext.PrintParameters +) + +browsingContext.PrintParameters = { + context: browsingContext.BrowsingContext, + ? background: bool .default false, + ? margin: browsingContext.PrintMarginParameters, + ? orientation: ("portrait" / "landscape") .default "portrait", + ? page: browsingContext.PrintPageParameters, + ? pageRanges: [*(js-uint / text)], + ? scale: (0.1..2.0) .default 1.0, + ? shrinkToFit: bool .default true, +} + +browsingContext.PrintMarginParameters = { + ? bottom: (float .ge 0.0) .default 1.0, + ? left: (float .ge 0.0) .default 1.0, + ? right: (float .ge 0.0) .default 1.0, + ? top: (float .ge 0.0) .default 1.0, +} + +; Minimum size is 1pt x 1pt. Conversion follows from +; https://www.w3.org/TR/css3-values/#absolute-lengths +browsingContext.PrintPageParameters = { + ? height: (float .ge 0.0352) .default 27.94, + ? width: (float .ge 0.0352) .default 21.59, +} + +browsingContext.PrintResult = { + data: text +} + +browsingContext.Reload = ( + method: "browsingContext.reload", + params: browsingContext.ReloadParameters +) + +browsingContext.ReloadParameters = { + context: browsingContext.BrowsingContext, + ? ignoreCache: bool, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.ReloadResult = browsingContext.NavigateResult + +browsingContext.SetViewport = ( + method: "browsingContext.setViewport", + params: browsingContext.SetViewportParameters +) + +browsingContext.SetViewportParameters = { + ? context: browsingContext.BrowsingContext, + ? viewport: browsingContext.Viewport / null, + ? devicePixelRatio: (float .gt 0.0) / null, + ? userContexts: [+browser.UserContext], +} + +browsingContext.Viewport = { + width: js-uint, + height: js-uint, +} + +browsingContext.SetViewportResult = EmptyResult + +browsingContext.TraverseHistory = ( + method: "browsingContext.traverseHistory", + params: browsingContext.TraverseHistoryParameters +) + +browsingContext.TraverseHistoryParameters = { + context: browsingContext.BrowsingContext, + delta: js-int, +} + +browsingContext.TraverseHistoryResult = EmptyResult + +browsingContext.ContextCreated = ( + method: "browsingContext.contextCreated", + params: browsingContext.Info +) + +browsingContext.ContextDestroyed = ( + method: "browsingContext.contextDestroyed", + params: browsingContext.Info +) + +browsingContext.NavigationStarted = ( + method: "browsingContext.navigationStarted", + params: browsingContext.NavigationInfo +) + +browsingContext.FragmentNavigated = ( + method: "browsingContext.fragmentNavigated", + params: browsingContext.NavigationInfo +) + +browsingContext.HistoryUpdated = ( + method: "browsingContext.historyUpdated", + params: browsingContext.HistoryUpdatedParameters +) + +browsingContext.HistoryUpdatedParameters = { + context: browsingContext.BrowsingContext, + timestamp: js-uint, + url: text +} + +browsingContext.DomContentLoaded = ( + method: "browsingContext.domContentLoaded", + params: browsingContext.NavigationInfo +) + +browsingContext.Load = ( + method: "browsingContext.load", + params: browsingContext.NavigationInfo +) + +browsingContext.DownloadWillBegin = ( + method: "browsingContext.downloadWillBegin", + params: browsingContext.DownloadWillBeginParams +) + +browsingContext.DownloadWillBeginParams = { + suggestedFilename: text, + browsingContext.BaseNavigationInfo +} + +browsingContext.DownloadEnd = ( + method: "browsingContext.downloadEnd", + params: browsingContext.DownloadEndParams +) + +browsingContext.DownloadEndParams = { + ( + browsingContext.DownloadCanceledParams // + browsingContext.DownloadCompleteParams + ) +} + +browsingContext.DownloadCanceledParams = ( + status: "canceled", + browsingContext.BaseNavigationInfo +) + +browsingContext.DownloadCompleteParams = ( + status: "complete", + filepath: text / null, + browsingContext.BaseNavigationInfo +) + +browsingContext.NavigationAborted = ( + method: "browsingContext.navigationAborted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationCommitted = ( + method: "browsingContext.navigationCommitted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationFailed = ( + method: "browsingContext.navigationFailed", + params: browsingContext.NavigationInfo +) + +browsingContext.UserPromptClosed = ( + method: "browsingContext.userPromptClosed", + params: browsingContext.UserPromptClosedParameters +) + +browsingContext.UserPromptClosedParameters = { + context: browsingContext.BrowsingContext, + accepted: bool, + type: browsingContext.UserPromptType, + ? userText: text +} + +browsingContext.UserPromptOpened = ( + method: "browsingContext.userPromptOpened", + params: browsingContext.UserPromptOpenedParameters +) + +browsingContext.UserPromptOpenedParameters = { + context: browsingContext.BrowsingContext, + handler: session.UserPromptHandlerType, + message: text, + type: browsingContext.UserPromptType, + ? defaultValue: text +} + +EmulationCommand = ( + emulation.SetForcedColorsModeThemeOverride // + emulation.SetGeolocationOverride // + emulation.SetLocaleOverride // + emulation.SetNetworkConditions // + emulation.SetScreenOrientationOverride // + emulation.SetScreenSettingsOverride // + emulation.SetScriptingEnabled // + emulation.SetScrollbarTypeOverride // + emulation.SetTimezoneOverride // + emulation.SetTouchOverride // + emulation.SetUserAgentOverride // + emulation.SetViewportMetaOverride +) + + +EmulationResult = ( + emulation.SetForcedColorsModeThemeOverrideResult / + emulation.SetGeolocationOverrideResult / + emulation.SetLocaleOverrideResult / + emulation.SetScreenOrientationOverrideResult / + emulation.SetScriptingEnabledResult / + emulation.SetScrollbarTypeOverrideResult / + emulation.SetTimezoneOverrideResult / + emulation.SetTouchOverrideResult / + emulation.SetUserAgentOverrideResult / + emulation.SetViewportMetaOverrideResult +) + +emulation.SetForcedColorsModeThemeOverride = ( + method: "emulation.setForcedColorsModeThemeOverride", + params: emulation.SetForcedColorsModeThemeOverrideParameters +) + +emulation.SetForcedColorsModeThemeOverrideParameters = { + theme: emulation.ForcedColorsModeTheme / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.ForcedColorsModeTheme = "light" / "dark" + +emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult + +emulation.SetGeolocationOverride = ( + method: "emulation.setGeolocationOverride", + params: emulation.SetGeolocationOverrideParameters +) + +emulation.SetGeolocationOverrideParameters = { + ( + (coordinates: emulation.GeolocationCoordinates / null) // + (error: emulation.GeolocationPositionError) + ), + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.GeolocationCoordinates = { + latitude: -90.0..90.0, + longitude: -180.0..180.0, + ? accuracy: (float .ge 0.0) .default 1.0, + ? altitude: float / null .default null, + ? altitudeAccuracy: (float .ge 0.0) / null .default null, + ? heading: (0.0...360.0) / null .default null, + ? speed: (float .ge 0.0) / null .default null, +} + +emulation.GeolocationPositionError = { + type: "positionUnavailable" +} + +emulation.SetGeolocationOverrideResult = EmptyResult + +emulation.SetLocaleOverride = ( + method: "emulation.setLocaleOverride", + params: emulation.SetLocaleOverrideParameters +) + +emulation.SetLocaleOverrideParameters = { + locale: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetLocaleOverrideResult = EmptyResult + +emulation.SetNetworkConditions = ( + method: "emulation.setNetworkConditions", + params: emulation.setNetworkConditionsParameters +) + +emulation.setNetworkConditionsParameters = { + networkConditions: emulation.NetworkConditions / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.NetworkConditions = emulation.NetworkConditionsOffline + +emulation.NetworkConditionsOffline = { + type: "offline" +} + +emulation.SetNetworkConditionsResult = EmptyResult + +emulation.SetScreenSettingsOverride = ( + method: "emulation.setScreenSettingsOverride", + params: emulation.SetScreenSettingsOverrideParameters +) + +emulation.ScreenArea = { + width: js-uint, + height: js-uint +} + +emulation.SetScreenSettingsOverrideParameters = { + screenArea: emulation.ScreenArea / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenSettingsOverrideResult = EmptyResult + +emulation.SetScreenOrientationOverride = ( + method: "emulation.setScreenOrientationOverride", + params: emulation.SetScreenOrientationOverrideParameters +) + +emulation.ScreenOrientationNatural = "portrait" / "landscape" +emulation.ScreenOrientationType = "portrait-primary" / "portrait-secondary" / "landscape-primary" / "landscape-secondary" + +emulation.ScreenOrientation = { + natural: emulation.ScreenOrientationNatural, + type: emulation.ScreenOrientationType +} + +emulation.SetScreenOrientationOverrideParameters = { + screenOrientation: emulation.ScreenOrientation / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenOrientationOverrideResult = EmptyResult + +emulation.SetUserAgentOverride = ( + method: "emulation.setUserAgentOverride", + params: emulation.SetUserAgentOverrideParameters +) + +emulation.SetUserAgentOverrideParameters = { + userAgent: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetUserAgentOverrideResult = EmptyResult + +emulation.SetViewportMetaOverride = ( + method: "emulation.setViewportMetaOverride", + params: emulation.SetViewportMetaOverrideParameters +) + +emulation.SetViewportMetaOverrideParameters = { + viewportMeta: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetViewportMetaOverrideResult = EmptyResult + +emulation.SetScriptingEnabled = ( + method: "emulation.setScriptingEnabled", + params: emulation.SetScriptingEnabledParameters +) + +emulation.SetScriptingEnabledParameters = { + enabled: false / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScriptingEnabledResult = EmptyResult + +emulation.SetScrollbarTypeOverride = ( + method: "emulation.setScrollbarTypeOverride", + params: emulation.SetScrollbarTypeOverrideParameters +) + +emulation.SetScrollbarTypeOverrideParameters = { + scrollbarType: "classic" / "overlay" / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScrollbarTypeOverrideResult = EmptyResult + +emulation.SetTimezoneOverride = ( + method: "emulation.setTimezoneOverride", + params: emulation.SetTimezoneOverrideParameters +) + +emulation.SetTimezoneOverrideParameters = { + timezone: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTimezoneOverrideResult = EmptyResult + +emulation.SetTouchOverride = ( + method: "emulation.setTouchOverride", + params: emulation.SetTouchOverrideParameters +) + +emulation.SetTouchOverrideParameters = { + maxTouchPoints: (js-uint .ge 1) / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTouchOverrideResult = EmptyResult + + +NetworkCommand = ( + network.AddDataCollector // + network.AddIntercept // + network.ContinueRequest // + network.ContinueResponse // + network.ContinueWithAuth // + network.DisownData // + network.FailRequest // + network.GetData // + network.ProvideResponse // + network.RemoveDataCollector // + network.RemoveIntercept // + network.SetCacheBehavior // + network.SetExtraHeaders +) + + + +NetworkResult = ( + network.AddDataCollectorResult / + network.AddInterceptResult / + network.ContinueRequestResult / + network.ContinueResponseResult / + network.ContinueWithAuthResult / + network.DisownDataResult / + network.FailRequestResult / + network.GetDataResult / + network.ProvideResponseResult / + network.RemoveDataCollectorResult / + network.RemoveInterceptResult / + network.SetCacheBehaviorResult / + network.SetExtraHeadersResult +) + +NetworkEvent = ( + network.AuthRequired // + network.BeforeRequestSent // + network.FetchError // + network.ResponseCompleted // + network.ResponseStarted +) + + +network.AuthChallenge = { + scheme: text, + realm: text, +} + +network.AuthCredentials = { + type: "password", + username: text, + password: text, +} + +network.BaseParameters = ( + context: browsingContext.BrowsingContext / null, + isBlocked: bool, + navigation: browsingContext.Navigation / null, + redirectCount: js-uint, + request: network.RequestData, + timestamp: js-uint, + ? intercepts: [+network.Intercept] +) + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.CookieHeader = { + name: text, + value: network.BytesValue, +} + +network.DataType = "request" / "response" + +network.FetchTimingInfo = { + timeOrigin: float, + requestTime: float, + redirectStart: float, + redirectEnd: float, + fetchStart: float, + dnsStart: float, + dnsEnd: float, + connectStart: float, + connectEnd: float, + tlsStart: float, + + requestStart: float, + responseStart: float, + + responseEnd: float, +} + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Initiator = { + ? columnNumber: js-uint, + ? lineNumber: js-uint, + ? request: network.Request, + ? stackTrace: script.StackTrace, + ? type: "parser" / "script" / "preflight" / "other" +} + +network.Intercept = text + +network.Request = text; + +network.RequestData = { + request: network.Request, + url: text, + method: text, + headers: [*network.Header], + cookies: [*network.Cookie], + headersSize: js-uint, + bodySize: js-uint / null, + destination: text, + initiatorType: text / null, + timings: network.FetchTimingInfo, +} + +network.ResponseContent = { + size: js-uint +} + +network.ResponseData = { + url: text, + protocol: text, + status: js-uint, + statusText: text, + fromCache: bool, + headers: [*network.Header], + mimeType: text, + bytesReceived: js-uint, + headersSize: js-uint / null, + bodySize: js-uint / null, + content: network.ResponseContent, + ?authChallenges: [*network.AuthChallenge], +} + + +network.SetCookieHeader = { + name: text, + value: network.BytesValue, + ? domain: text, + ? httpOnly: bool, + ? expiry: text, + ? maxAge: js-int, + ? path: text, + ? sameSite: network.SameSite, + ? secure: bool, +} + +network.UrlPattern = ( + network.UrlPatternPattern / + network.UrlPatternString +) + +network.UrlPatternPattern = { + type: "pattern", + ?protocol: text, + ?hostname: text, + ?port: text, + ?pathname: text, + ?search: text, +} + + +network.UrlPatternString = { + type: "string", + pattern: text, +} + + +network.AddDataCollector = ( + method: "network.addDataCollector", + params: network.AddDataCollectorParameters +) + +network.AddDataCollectorParameters = { + dataTypes: [+network.DataType], + maxEncodedDataSize: js-uint, + ? collectorType: network.CollectorType .default "blob", + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +network.AddDataCollectorResult = { + collector: network.Collector +} + +network.AddIntercept = ( + method: "network.addIntercept", + params: network.AddInterceptParameters +) + +network.AddInterceptParameters = { + phases: [+network.InterceptPhase], + ? contexts: [+browsingContext.BrowsingContext], + ? urlPatterns: [*network.UrlPattern], +} + +network.InterceptPhase = "beforeRequestSent" / "responseStarted" / + "authRequired" + +network.AddInterceptResult = { + intercept: network.Intercept +} + +network.ContinueRequest = ( + method: "network.continueRequest", + params: network.ContinueRequestParameters +) + +network.ContinueRequestParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.CookieHeader], + ?headers: [*network.Header], + ?method: text, + ?url: text, +} + +network.ContinueRequestResult = EmptyResult + +network.ContinueResponse = ( + method: "network.continueResponse", + params: network.ContinueResponseParameters +) + +network.ContinueResponseParameters = { + request: network.Request, + ?cookies: [*network.SetCookieHeader] + ?credentials: network.AuthCredentials, + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ContinueResponseResult = EmptyResult + +network.ContinueWithAuth = ( + method: "network.continueWithAuth", + params: network.ContinueWithAuthParameters +) + +network.ContinueWithAuthParameters = { + request: network.Request, + (network.ContinueWithAuthCredentials // network.ContinueWithAuthNoCredentials) +} + +network.ContinueWithAuthCredentials = ( + action: "provideCredentials", + credentials: network.AuthCredentials +) + +network.ContinueWithAuthNoCredentials = ( + action: "default" / "cancel" +) + +network.ContinueWithAuthResult = EmptyResult + +network.DisownData = ( + method: "network.disownData", + params: network.disownDataParameters +) + +network.disownDataParameters = { + dataType: network.DataType, + collector: network.Collector, + request: network.Request, +} + +network.DisownDataResult = EmptyResult + +network.FailRequest = ( + method: "network.failRequest", + params: network.FailRequestParameters +) + +network.FailRequestParameters = { + request: network.Request, +} + +network.FailRequestResult = EmptyResult + +network.GetData = ( + method: "network.getData", + params: network.GetDataParameters +) + +network.GetDataParameters = { + dataType: network.DataType, + ? collector: network.Collector, + ? disown: bool .default false, + request: network.Request, +} + +network.GetDataResult = { + bytes: network.BytesValue, +} + +network.ProvideResponse = ( + method: "network.provideResponse", + params: network.ProvideResponseParameters +) + +network.ProvideResponseParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.SetCookieHeader], + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ProvideResponseResult = EmptyResult + +network.RemoveDataCollector = ( + method: "network.removeDataCollector", + params: network.RemoveDataCollectorParameters +) + +network.RemoveDataCollectorParameters = { + collector: network.Collector +} + +network.RemoveDataCollectorResult = EmptyResult + +network.RemoveIntercept = ( + method: "network.removeIntercept", + params: network.RemoveInterceptParameters +) + +network.RemoveInterceptParameters = { + intercept: network.Intercept +} + +network.RemoveInterceptResult = EmptyResult + +network.SetCacheBehavior = ( + method: "network.setCacheBehavior", + params: network.SetCacheBehaviorParameters +) + +network.SetCacheBehaviorParameters = { + cacheBehavior: "default" / "bypass", + ? contexts: [+browsingContext.BrowsingContext] +} + +network.SetCacheBehaviorResult = EmptyResult + +network.SetExtraHeaders = ( + method: "network.setExtraHeaders", + params: network.SetExtraHeadersParameters +) + +network.SetExtraHeadersParameters = { + headers: [*network.Header] + ? contexts: [+browsingContext.BrowsingContext] + ? userContexts: [+browser.UserContext] +} + +network.SetExtraHeadersResult = EmptyResult + +network.AuthRequired = ( + method: "network.authRequired", + params: network.AuthRequiredParameters +) + +network.AuthRequiredParameters = { + network.BaseParameters, + response: network.ResponseData +} + + network.BeforeRequestSent = ( + method: "network.beforeRequestSent", + params: network.BeforeRequestSentParameters + ) + +network.BeforeRequestSentParameters = { + network.BaseParameters, + ? initiator: network.Initiator, +} + + network.FetchError = ( + method: "network.fetchError", + params: network.FetchErrorParameters + ) + +network.FetchErrorParameters = { + network.BaseParameters, + errorText: text, +} + + network.ResponseCompleted = ( + method: "network.responseCompleted", + params: network.ResponseCompletedParameters + ) + +network.ResponseCompletedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + + network.ResponseStarted = ( + method: "network.responseStarted", + params: network.ResponseStartedParameters + ) + +network.ResponseStartedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + +ScriptCommand = ( + script.AddPreloadScript // + script.CallFunction // + script.Disown // + script.Evaluate // + script.GetRealms // + script.RemovePreloadScript +) + +ScriptResult = ( + script.AddPreloadScriptResult / + script.CallFunctionResult / + script.DisownResult / + script.EvaluateResult / + script.GetRealmsResult / + script.RemovePreloadScriptResult +) + +ScriptEvent = ( + script.Message // + script.RealmCreated // + script.RealmDestroyed +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmInfo = ( + script.WindowRealmInfo / + script.DedicatedWorkerRealmInfo / + script.SharedWorkerRealmInfo / + script.ServiceWorkerRealmInfo / + script.WorkerRealmInfo / + script.PaintWorkletRealmInfo / + script.AudioWorkletRealmInfo / + script.WorkletRealmInfo +) + +script.BaseRealmInfo = ( + realm: script.Realm, + origin: text +) + +script.WindowRealmInfo = { + script.BaseRealmInfo, + type: "window", + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.DedicatedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "dedicated-worker", + owners: [script.Realm] +} + +script.SharedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "shared-worker" +} + +script.ServiceWorkerRealmInfo = { + script.BaseRealmInfo, + type: "service-worker" +} + +script.WorkerRealmInfo = { + script.BaseRealmInfo, + type: "worker" +} + +script.PaintWorkletRealmInfo = { + script.BaseRealmInfo, + type: "paint-worklet" +} + +script.AudioWorkletRealmInfo = { + script.BaseRealmInfo, + type: "audio-worklet" +} + +script.WorkletRealmInfo = { + script.BaseRealmInfo, + type: "worklet" +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.Source = { + realm: script.Realm, + ? context: browsingContext.BrowsingContext +} + +script.RealmTarget = { + realm: script.Realm +} + +script.ContextTarget = { + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.Target = ( + script.ContextTarget / + script.RealmTarget +) + +script.AddPreloadScript = ( + method: "script.addPreloadScript", + params: script.AddPreloadScriptParameters +) + +script.AddPreloadScriptParameters = { + functionDeclaration: text, + ? arguments: [*script.ChannelValue], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], + ? sandbox: text +} + +script.AddPreloadScriptResult = { + script: script.PreloadScript +} + +script.Disown = ( + method: "script.disown", + params: script.DisownParameters +) + +script.DisownParameters = { + handles: [*script.Handle] + target: script.Target; +} + +script.DisownResult = EmptyResult + +script.CallFunction = ( + method: "script.callFunction", + params: script.CallFunctionParameters +) + +script.CallFunctionParameters = { + functionDeclaration: text, + awaitPromise: bool, + target: script.Target, + ? arguments: [*script.LocalValue], + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? this: script.LocalValue, + ? userActivation: bool .default false, +} + +script.CallFunctionResult = script.EvaluateResult + +script.Evaluate = ( + method: "script.evaluate", + params: script.EvaluateParameters +) + +script.EvaluateParameters = { + expression: text, + target: script.Target, + awaitPromise: bool, + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? userActivation: bool .default false, +} + +script.GetRealms = ( + method: "script.getRealms", + params: script.GetRealmsParameters +) + +script.GetRealmsParameters = { + ? context: browsingContext.BrowsingContext, + ? type: script.RealmType, +} + +script.GetRealmsResult = { + realms: [*script.RealmInfo] +} + +script.RemovePreloadScript = ( + method: "script.removePreloadScript", + params: script.RemovePreloadScriptParameters +) + +script.RemovePreloadScriptParameters = { + script: script.PreloadScript +} + +script.RemovePreloadScriptResult = EmptyResult + + script.Message = ( + method: "script.message", + params: script.MessageParameters + ) + +script.MessageParameters = { + channel: script.Channel, + data: script.RemoteValue, + source: script.Source, +} + +script.RealmCreated = ( + method: "script.realmCreated", + params: script.RealmInfo +) + +script.RealmDestroyed = ( + method: "script.realmDestroyed", + params: script.RealmDestroyedParameters +) + +script.RealmDestroyedParameters = { + realm: script.Realm +} + + +StorageCommand = ( + storage.DeleteCookies // + storage.GetCookies // + storage.SetCookie +) + +StorageResult = ( + storage.DeleteCookiesResult / + storage.GetCookiesResult / + storage.SetCookieResult +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookies = ( + method: "storage.getCookies", + params: storage.GetCookiesParameters +) + + +storage.CookieFilter = { + ? name: text, + ? value: network.BytesValue, + ? domain: text, + ? path: text, + ? size: js-uint, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.BrowsingContextPartitionDescriptor = { + type: "context", + context: browsingContext.BrowsingContext +} + +storage.StorageKeyPartitionDescriptor = { + type: "storageKey", + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.PartitionDescriptor = ( + storage.BrowsingContextPartitionDescriptor / + storage.StorageKeyPartitionDescriptor +) + +storage.GetCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.GetCookiesResult = { + cookies: [*network.Cookie], + partitionKey: storage.PartitionKey, +} + +storage.SetCookie = ( + method: "storage.setCookie", + params: storage.SetCookieParameters, +) + + +storage.PartialCookie = { + name: text, + value: network.BytesValue, + domain: text, + ? path: text, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.SetCookieParameters = { + cookie: storage.PartialCookie, + ? partition: storage.PartitionDescriptor, +} + +storage.SetCookieResult = { + partitionKey: storage.PartitionKey +} + +storage.DeleteCookies = ( + method: "storage.deleteCookies", + params: storage.DeleteCookiesParameters, +) + +storage.DeleteCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.DeleteCookiesResult = { + partitionKey: storage.PartitionKey +} + +LogEvent = ( + log.EntryAdded +) + +log.Level = "debug" / "info" / "warn" / "error" + +log.Entry = ( + log.GenericLogEntry / + log.ConsoleLogEntry / + log.JavascriptLogEntry +) + +log.BaseLogEntry = ( + level: log.Level, + source: script.Source, + text: text / null, + timestamp: js-uint, + ? stackTrace: script.StackTrace, +) + +log.GenericLogEntry = { + log.BaseLogEntry, + type: text, +} + +log.ConsoleLogEntry = { + log.BaseLogEntry, + type: "console", + method: text, + args: [*script.RemoteValue], +} + +log.JavascriptLogEntry = { + log.BaseLogEntry, + type: "javascript", +} + +log.EntryAdded = ( + method: "log.entryAdded", + params: log.Entry, +) + +InputCommand = ( + input.PerformActions // + input.ReleaseActions // + input.SetFiles +) + +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + + +InputEvent = ( + input.FileDialogOpened +) + +input.ElementOrigin = { + type: "element", + element: script.SharedReference +} + +input.PerformActions = ( + method: "input.performActions", + params: input.PerformActionsParameters +) + +input.PerformActionsParameters = { + context: browsingContext.BrowsingContext, + actions: [*input.SourceActions] +} + +input.SourceActions = ( + input.NoneSourceActions / + input.KeySourceActions / + input.PointerSourceActions / + input.WheelSourceActions +) + +input.NoneSourceActions = { + type: "none", + id: text, + actions: [*input.NoneSourceAction] +} + +input.NoneSourceAction = input.PauseAction + +input.KeySourceActions = { + type: "key", + id: text, + actions: [*input.KeySourceAction] +} + +input.KeySourceAction = ( + input.PauseAction / + input.KeyDownAction / + input.KeyUpAction +) + +input.PointerSourceActions = { + type: "pointer", + id: text, + ? parameters: input.PointerParameters, + actions: [*input.PointerSourceAction] +} + +input.PointerType = "mouse" / "pen" / "touch" + +input.PointerParameters = { + ? pointerType: input.PointerType .default "mouse" +} + +input.PointerSourceAction = ( + input.PauseAction / + input.PointerDownAction / + input.PointerUpAction / + input.PointerMoveAction +) + +input.WheelSourceActions = { + type: "wheel", + id: text, + actions: [*input.WheelSourceAction] +} + +input.WheelSourceAction = ( + input.PauseAction / + input.WheelScrollAction +) + +input.PauseAction = { + type: "pause", + ? duration: js-uint +} + +input.KeyDownAction = { + type: "keyDown", + value: text +} + +input.KeyUpAction = { + type: "keyUp", + value: text +} + +input.PointerUpAction = { + type: "pointerUp", + button: js-uint, +} + +input.PointerDownAction = { + type: "pointerDown", + button: js-uint, + input.PointerCommonProperties +} + +input.PointerMoveAction = { + type: "pointerMove", + x: float, + y: float, + ? duration: js-uint, + ? origin: input.Origin, + input.PointerCommonProperties +} + +input.WheelScrollAction = { + type: "scroll", + x: js-int, + y: js-int, + deltaX: js-int, + deltaY: js-int, + ? duration: js-uint, + ? origin: input.Origin .default "viewport", +} + +input.PointerCommonProperties = ( + ? width: js-uint .default 1, + ? height: js-uint .default 1, + ? pressure: float .default 0.0, + ? tangentialPressure: float .default 0.0, + ? twist: (0..359) .default 0, + ; 0 .. Math.PI / 2 + ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ; 0 .. 2 * Math.PI + ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, +) + +input.Origin = "viewport" / "pointer" / input.ElementOrigin + +input.PerformActionsResult = EmptyResult + +input.ReleaseActions = ( + method: "input.releaseActions", + params: input.ReleaseActionsParameters +) + +input.ReleaseActionsParameters = { + context: browsingContext.BrowsingContext, +} + +input.ReleaseActionsResult = EmptyResult + +input.SetFiles = ( + method: "input.setFiles", + params: input.SetFilesParameters +) + +input.SetFilesParameters = { + context: browsingContext.BrowsingContext, + element: script.SharedReference, + files: [*text] +} + +input.SetFilesResult = EmptyResult + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionCommand = ( + webExtension.Install // + webExtension.Uninstall +) + +WebExtensionResult = ( + webExtension.InstallResult / + webExtension.UninstallResult +) + +webExtension.Extension = text + +webExtension.Install = ( + method: "webExtension.install", + params: webExtension.InstallParameters +) + +webExtension.InstallParameters = { + extensionData: webExtension.ExtensionData, +} + +webExtension.ExtensionData = ( + webExtension.ExtensionArchivePath / + webExtension.ExtensionBase64Encoded / + webExtension.ExtensionPath +) + +webExtension.ExtensionPath = { + type: "path", + path: text, +} + +webExtension.ExtensionArchivePath = { + type: "archivePath", + path: text, +} + +webExtension.ExtensionBase64Encoded = { + type: "base64", + value: text, +} + +webExtension.InstallResult = { + extension: webExtension.Extension +} + +webExtension.Uninstall = ( + method: "webExtension.uninstall", + params: webExtension.UninstallParameters +) + +webExtension.UninstallParameters = { + extension: webExtension.Extension, +} + +webExtension.UninstallResult = EmptyResult diff --git a/common/bidi/spec/local.cddl b/common/bidi/spec/local.cddl new file mode 100644 index 0000000000000..d43af0ae11b03 --- /dev/null +++ b/common/bidi/spec/local.cddl @@ -0,0 +1,1331 @@ +Message = ( + CommandResponse / + ErrorResponse / + Event +) + +CommandResponse = { + type: "success", + id: js-uint, + result: ResultData, + Extensible +} + +ErrorResponse = { + type: "error", + id: js-uint / null, + error: ErrorCode, + message: text, + ? stacktrace: text, + Extensible +} + +ResultData = ( + BrowserResult / + BrowsingContextResult / + EmulationResult / + InputResult / + NetworkResult / + ScriptResult / + SessionResult / + StorageResult / + WebExtensionResult +) + +EmptyResult = { + Extensible +} + +Event = { + type: "event", + EventData, + Extensible +} + +EventData = ( + BrowsingContextEvent // + InputEvent // + LogEvent // + NetworkEvent // + ScriptEvent +) + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +ErrorCode = "invalid argument" / + "invalid selector" / + "invalid session id" / + "invalid web extension" / + "move target out of bounds" / + "no such alert" / + "no such network collector" / + "no such element" / + "no such frame" / + "no such handle" / + "no such history entry" / + "no such intercept" / + "no such network data" / + "no such node" / + "no such request" / + "no such script" / + "no such storage partition" / + "no such user context" / + "no such web extension" / + "session not created" / + "unable to capture screen" / + "unable to close browser" / + "unable to set cookie" / + "unable to set file input" / + "unavailable network data" / + "underspecified storage partition" / + "unknown command" / + "unknown error" / + "unsupported operation" + +SessionResult = ( + session.EndResult / + session.NewResult / + session.StatusResult / + session.SubscribeResult / + session.UnsubscribeResult +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.StatusResult = { + ready: bool, + message: text, +} + +session.NewResult = { + sessionId: text, + capabilities: { + acceptInsecureCerts: bool, + browserName: text, + browserVersion: text, + platformName: text, + setWindowRect: bool, + userAgent: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + ? webSocketUrl: text, + Extensible + } +} + +session.EndResult = EmptyResult + +session.SubscribeResult = { + subscription: session.Subscription, +} + +session.UnsubscribeResult = EmptyResult + +BrowserResult = ( + browser.CloseResult / + browser.CreateUserContextResult / + browser.GetClientWindowsResult / + browser.GetUserContextsResult / + browser.RemoveUserContextResult / + browser.SetClientWindowStateResult / + browser.SetDownloadBehaviorResult +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.CloseResult = EmptyResult + +browser.CreateUserContextResult = browser.UserContextInfo + +browser.GetClientWindowsResult = { + clientWindows: [ * browser.ClientWindowInfo] +} + +browser.GetUserContextsResult = { + userContexts: [ + browser.UserContextInfo] +} + +browser.RemoveUserContextResult = EmptyResult + +browser.SetClientWindowStateResult = browser.ClientWindowInfo + +browser.SetDownloadBehaviorResult = EmptyResult + +BrowsingContextResult = ( + browsingContext.ActivateResult / + browsingContext.CaptureScreenshotResult / + browsingContext.CloseResult / + browsingContext.CreateResult / + browsingContext.GetTreeResult / + browsingContext.HandleUserPromptResult / + browsingContext.LocateNodesResult / + browsingContext.NavigateResult / + browsingContext.PrintResult / + browsingContext.ReloadResult / + browsingContext.SetViewportResult / + browsingContext.TraverseHistoryResult +) + +BrowsingContextEvent = ( + browsingContext.ContextCreated // + browsingContext.ContextDestroyed // + browsingContext.DomContentLoaded // + browsingContext.DownloadEnd // + browsingContext.DownloadWillBegin // + browsingContext.FragmentNavigated // + browsingContext.HistoryUpdated // + browsingContext.Load // + browsingContext.NavigationAborted // + browsingContext.NavigationCommitted // + browsingContext.NavigationFailed // + browsingContext.NavigationStarted // + browsingContext.UserPromptClosed // + browsingContext.UserPromptOpened +) + +browsingContext.BrowsingContext = text; + +browsingContext.InfoList = [*browsingContext.Info] + +browsingContext.Info = { + children: browsingContext.InfoList / null, + clientWindow: browser.ClientWindow, + context: browsingContext.BrowsingContext, + originalOpener: browsingContext.BrowsingContext / null, + url: text, + userContext: browser.UserContext, + ? parent: browsingContext.BrowsingContext / null, +} + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.BaseNavigationInfo = ( + context: browsingContext.BrowsingContext, + navigation: browsingContext.Navigation / null, + timestamp: js-uint, + url: text, +) + +browsingContext.NavigationInfo = { + browsingContext.BaseNavigationInfo +} + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.ActivateResult = EmptyResult + +browsingContext.CaptureScreenshotResult = { + data: text +} + +browsingContext.CloseResult = EmptyResult + +browsingContext.CreateResult = { + context: browsingContext.BrowsingContext +} + +browsingContext.GetTreeResult = { + contexts: browsingContext.InfoList +} + +browsingContext.HandleUserPromptResult = EmptyResult + +browsingContext.LocateNodesResult = { + nodes: [ * script.NodeRemoteValue ] +} + +browsingContext.NavigateResult = { + navigation: browsingContext.Navigation / null, + url: text, +} + +browsingContext.PrintResult = { + data: text +} + +browsingContext.ReloadResult = browsingContext.NavigateResult + +browsingContext.SetViewportResult = EmptyResult + +browsingContext.TraverseHistoryResult = EmptyResult + +browsingContext.ContextCreated = ( + method: "browsingContext.contextCreated", + params: browsingContext.Info +) + +browsingContext.ContextDestroyed = ( + method: "browsingContext.contextDestroyed", + params: browsingContext.Info +) + +browsingContext.NavigationStarted = ( + method: "browsingContext.navigationStarted", + params: browsingContext.NavigationInfo +) + +browsingContext.FragmentNavigated = ( + method: "browsingContext.fragmentNavigated", + params: browsingContext.NavigationInfo +) + +browsingContext.HistoryUpdated = ( + method: "browsingContext.historyUpdated", + params: browsingContext.HistoryUpdatedParameters +) + +browsingContext.HistoryUpdatedParameters = { + context: browsingContext.BrowsingContext, + timestamp: js-uint, + url: text +} + +browsingContext.DomContentLoaded = ( + method: "browsingContext.domContentLoaded", + params: browsingContext.NavigationInfo +) + +browsingContext.Load = ( + method: "browsingContext.load", + params: browsingContext.NavigationInfo +) + +browsingContext.DownloadWillBegin = ( + method: "browsingContext.downloadWillBegin", + params: browsingContext.DownloadWillBeginParams +) + +browsingContext.DownloadWillBeginParams = { + suggestedFilename: text, + browsingContext.BaseNavigationInfo +} + +browsingContext.DownloadEnd = ( + method: "browsingContext.downloadEnd", + params: browsingContext.DownloadEndParams +) + +browsingContext.DownloadEndParams = { + ( + browsingContext.DownloadCanceledParams // + browsingContext.DownloadCompleteParams + ) +} + +browsingContext.DownloadCanceledParams = ( + status: "canceled", + browsingContext.BaseNavigationInfo +) + +browsingContext.DownloadCompleteParams = ( + status: "complete", + filepath: text / null, + browsingContext.BaseNavigationInfo +) + +browsingContext.NavigationAborted = ( + method: "browsingContext.navigationAborted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationCommitted = ( + method: "browsingContext.navigationCommitted", + params: browsingContext.NavigationInfo +) + +browsingContext.NavigationFailed = ( + method: "browsingContext.navigationFailed", + params: browsingContext.NavigationInfo +) + +browsingContext.UserPromptClosed = ( + method: "browsingContext.userPromptClosed", + params: browsingContext.UserPromptClosedParameters +) + +browsingContext.UserPromptClosedParameters = { + context: browsingContext.BrowsingContext, + accepted: bool, + type: browsingContext.UserPromptType, + ? userText: text +} + +browsingContext.UserPromptOpened = ( + method: "browsingContext.userPromptOpened", + params: browsingContext.UserPromptOpenedParameters +) + +browsingContext.UserPromptOpenedParameters = { + context: browsingContext.BrowsingContext, + handler: session.UserPromptHandlerType, + message: text, + type: browsingContext.UserPromptType, + ? defaultValue: text +} + +EmulationResult = ( + emulation.SetForcedColorsModeThemeOverrideResult / + emulation.SetGeolocationOverrideResult / + emulation.SetLocaleOverrideResult / + emulation.SetScreenOrientationOverrideResult / + emulation.SetScriptingEnabledResult / + emulation.SetScrollbarTypeOverrideResult / + emulation.SetTimezoneOverrideResult / + emulation.SetTouchOverrideResult / + emulation.SetUserAgentOverrideResult / + emulation.SetViewportMetaOverrideResult +) + +emulation.SetForcedColorsModeThemeOverrideResult = EmptyResult + +emulation.SetGeolocationOverrideResult = EmptyResult + +emulation.SetLocaleOverrideResult = EmptyResult + +emulation.SetNetworkConditionsResult = EmptyResult + +emulation.SetScreenSettingsOverrideResult = EmptyResult + +emulation.SetScreenOrientationOverrideResult = EmptyResult + +emulation.SetUserAgentOverrideResult = EmptyResult + +emulation.SetViewportMetaOverrideResult = EmptyResult + +emulation.SetScriptingEnabledResult = EmptyResult + +emulation.SetScrollbarTypeOverrideResult = EmptyResult + +emulation.SetTimezoneOverrideResult = EmptyResult + +emulation.SetTouchOverrideResult = EmptyResult + + +NetworkResult = ( + network.AddDataCollectorResult / + network.AddInterceptResult / + network.ContinueRequestResult / + network.ContinueResponseResult / + network.ContinueWithAuthResult / + network.DisownDataResult / + network.FailRequestResult / + network.GetDataResult / + network.ProvideResponseResult / + network.RemoveDataCollectorResult / + network.RemoveInterceptResult / + network.SetCacheBehaviorResult / + network.SetExtraHeadersResult +) + +NetworkEvent = ( + network.AuthRequired // + network.BeforeRequestSent // + network.FetchError // + network.ResponseCompleted // + network.ResponseStarted +) + + +network.AuthChallenge = { + scheme: text, + realm: text, +} + +network.BaseParameters = ( + context: browsingContext.BrowsingContext / null, + isBlocked: bool, + navigation: browsingContext.Navigation / null, + redirectCount: js-uint, + request: network.RequestData, + timestamp: js-uint, + ? intercepts: [+network.Intercept] +) + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.DataType = "request" / "response" + +network.FetchTimingInfo = { + timeOrigin: float, + requestTime: float, + redirectStart: float, + redirectEnd: float, + fetchStart: float, + dnsStart: float, + dnsEnd: float, + connectStart: float, + connectEnd: float, + tlsStart: float, + + requestStart: float, + responseStart: float, + + responseEnd: float, +} + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Initiator = { + ? columnNumber: js-uint, + ? lineNumber: js-uint, + ? request: network.Request, + ? stackTrace: script.StackTrace, + ? type: "parser" / "script" / "preflight" / "other" +} + +network.Intercept = text + +network.Request = text; + +network.RequestData = { + request: network.Request, + url: text, + method: text, + headers: [*network.Header], + cookies: [*network.Cookie], + headersSize: js-uint, + bodySize: js-uint / null, + destination: text, + initiatorType: text / null, + timings: network.FetchTimingInfo, +} + +network.ResponseContent = { + size: js-uint +} + +network.ResponseData = { + url: text, + protocol: text, + status: js-uint, + statusText: text, + fromCache: bool, + headers: [*network.Header], + mimeType: text, + bytesReceived: js-uint, + headersSize: js-uint / null, + bodySize: js-uint / null, + content: network.ResponseContent, + ?authChallenges: [*network.AuthChallenge], +} + +network.AddDataCollectorResult = { + collector: network.Collector +} + +network.AddInterceptResult = { + intercept: network.Intercept +} + +network.ContinueRequestResult = EmptyResult + +network.ContinueResponseResult = EmptyResult + +network.ContinueWithAuthResult = EmptyResult + +network.DisownDataResult = EmptyResult + +network.FailRequestResult = EmptyResult + +network.GetDataResult = { + bytes: network.BytesValue, +} + +network.ProvideResponseResult = EmptyResult + +network.RemoveDataCollectorResult = EmptyResult + +network.RemoveInterceptResult = EmptyResult + +network.SetCacheBehaviorResult = EmptyResult + +network.SetExtraHeadersResult = EmptyResult + +network.AuthRequired = ( + method: "network.authRequired", + params: network.AuthRequiredParameters +) + +network.AuthRequiredParameters = { + network.BaseParameters, + response: network.ResponseData +} + + network.BeforeRequestSent = ( + method: "network.beforeRequestSent", + params: network.BeforeRequestSentParameters + ) + +network.BeforeRequestSentParameters = { + network.BaseParameters, + ? initiator: network.Initiator, +} + + network.FetchError = ( + method: "network.fetchError", + params: network.FetchErrorParameters + ) + +network.FetchErrorParameters = { + network.BaseParameters, + errorText: text, +} + + network.ResponseCompleted = ( + method: "network.responseCompleted", + params: network.ResponseCompletedParameters + ) + +network.ResponseCompletedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + + network.ResponseStarted = ( + method: "network.responseStarted", + params: network.ResponseStartedParameters + ) + +network.ResponseStartedParameters = { + network.BaseParameters, + response: network.ResponseData, +} + +ScriptResult = ( + script.AddPreloadScriptResult / + script.CallFunctionResult / + script.DisownResult / + script.EvaluateResult / + script.GetRealmsResult / + script.RemovePreloadScriptResult +) + +ScriptEvent = ( + script.Message // + script.RealmCreated // + script.RealmDestroyed +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmInfo = ( + script.WindowRealmInfo / + script.DedicatedWorkerRealmInfo / + script.SharedWorkerRealmInfo / + script.ServiceWorkerRealmInfo / + script.WorkerRealmInfo / + script.PaintWorkletRealmInfo / + script.AudioWorkletRealmInfo / + script.WorkletRealmInfo +) + +script.BaseRealmInfo = ( + realm: script.Realm, + origin: text +) + +script.WindowRealmInfo = { + script.BaseRealmInfo, + type: "window", + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.DedicatedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "dedicated-worker", + owners: [script.Realm] +} + +script.SharedWorkerRealmInfo = { + script.BaseRealmInfo, + type: "shared-worker" +} + +script.ServiceWorkerRealmInfo = { + script.BaseRealmInfo, + type: "service-worker" +} + +script.WorkerRealmInfo = { + script.BaseRealmInfo, + type: "worker" +} + +script.PaintWorkletRealmInfo = { + script.BaseRealmInfo, + type: "paint-worklet" +} + +script.AudioWorkletRealmInfo = { + script.BaseRealmInfo, + type: "audio-worklet" +} + +script.WorkletRealmInfo = { + script.BaseRealmInfo, + type: "worklet" +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.Source = { + realm: script.Realm, + ? context: browsingContext.BrowsingContext +} + +script.AddPreloadScriptResult = { + script: script.PreloadScript +} + +script.DisownResult = EmptyResult + +script.CallFunctionResult = script.EvaluateResult + +script.GetRealmsResult = { + realms: [*script.RealmInfo] +} + +script.RemovePreloadScriptResult = EmptyResult + + script.Message = ( + method: "script.message", + params: script.MessageParameters + ) + +script.MessageParameters = { + channel: script.Channel, + data: script.RemoteValue, + source: script.Source, +} + +script.RealmCreated = ( + method: "script.realmCreated", + params: script.RealmInfo +) + +script.RealmDestroyed = ( + method: "script.realmDestroyed", + params: script.RealmDestroyedParameters +) + +script.RealmDestroyedParameters = { + realm: script.Realm +} + + +StorageResult = ( + storage.DeleteCookiesResult / + storage.GetCookiesResult / + storage.SetCookieResult +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookiesResult = { + cookies: [*network.Cookie], + partitionKey: storage.PartitionKey, +} + +storage.SetCookieResult = { + partitionKey: storage.PartitionKey +} + +storage.DeleteCookiesResult = { + partitionKey: storage.PartitionKey +} + +LogEvent = ( + log.EntryAdded +) + +log.Level = "debug" / "info" / "warn" / "error" + +log.Entry = ( + log.GenericLogEntry / + log.ConsoleLogEntry / + log.JavascriptLogEntry +) + +log.BaseLogEntry = ( + level: log.Level, + source: script.Source, + text: text / null, + timestamp: js-uint, + ? stackTrace: script.StackTrace, +) + +log.GenericLogEntry = { + log.BaseLogEntry, + type: text, +} + +log.ConsoleLogEntry = { + log.BaseLogEntry, + type: "console", + method: text, + args: [*script.RemoteValue], +} + +log.JavascriptLogEntry = { + log.BaseLogEntry, + type: "javascript", +} + +log.EntryAdded = ( + method: "log.entryAdded", + params: log.Entry, +) + + +InputEvent = ( + input.FileDialogOpened +) + +input.PerformActionsResult = EmptyResult + +input.ReleaseActionsResult = EmptyResult + +input.SetFilesResult = EmptyResult + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionResult = ( + webExtension.InstallResult / + webExtension.UninstallResult +) + +webExtension.Extension = text + +webExtension.InstallResult = { + extension: webExtension.Extension +} + +webExtension.UninstallResult = EmptyResult diff --git a/common/bidi/spec/remote.cddl b/common/bidi/spec/remote.cddl new file mode 100644 index 0000000000000..a98859a021e12 --- /dev/null +++ b/common/bidi/spec/remote.cddl @@ -0,0 +1,1716 @@ +Command = { + id: js-uint, + CommandData, + Extensible, +} + +CommandData = ( + BrowserCommand // + BrowsingContextCommand // + EmulationCommand // + InputCommand // + NetworkCommand // + ScriptCommand // + SessionCommand // + StorageCommand // + WebExtensionCommand +) + +EmptyParams = { + Extensible +} + +Extensible = (*text => any) + +js-int = -9007199254740991..9007199254740991 +js-uint = 0..9007199254740991 + +SessionCommand = ( + session.End // + session.New // + session.Status // + session.Subscribe // + session.Unsubscribe +) + +session.CapabilitiesRequest = { + ? alwaysMatch: session.CapabilityRequest, + ? firstMatch: [*session.CapabilityRequest] +} + +session.CapabilityRequest = { + ? acceptInsecureCerts: bool, + ? browserName: text, + ? browserVersion: text, + ? platformName: text, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler, + Extensible +} + +session.ProxyConfiguration = { + session.AutodetectProxyConfiguration // + session.DirectProxyConfiguration // + session.ManualProxyConfiguration // + session.PacProxyConfiguration // + session.SystemProxyConfiguration +} + +session.AutodetectProxyConfiguration = ( + proxyType: "autodetect", + Extensible +) + +session.DirectProxyConfiguration = ( + proxyType: "direct", + Extensible +) + +session.ManualProxyConfiguration = ( + proxyType: "manual", + ? httpProxy: text, + ? sslProxy: text, + ? session.SocksProxyConfiguration, + ? noProxy: [*text], + Extensible +) + +session.SocksProxyConfiguration = ( + socksProxy: text, + socksVersion: 0..255, +) + +session.PacProxyConfiguration = ( + proxyType: "pac", + proxyAutoconfigUrl: text, + Extensible +) + +session.SystemProxyConfiguration = ( + proxyType: "system", + Extensible +) + + +session.UserPromptHandler = { + ? alert: session.UserPromptHandlerType, + ? beforeUnload: session.UserPromptHandlerType, + ? confirm: session.UserPromptHandlerType, + ? default: session.UserPromptHandlerType, + ? file: session.UserPromptHandlerType, + ? prompt: session.UserPromptHandlerType, +} + +session.UserPromptHandlerType = "accept" / "dismiss" / "ignore"; + +session.Subscription = text + +session.SubscribeParameters = { + events: [+text], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +session.UnsubscribeByIDRequest = { + subscriptions: [+session.Subscription], +} + +session.UnsubscribeByAttributesRequest = { + events: [+text], +} + +session.Status = ( + method: "session.status", + params: EmptyParams, +) + +session.New = ( + method: "session.new", + params: session.NewParameters +) + +session.NewParameters = { + capabilities: session.CapabilitiesRequest +} + +session.End = ( + method: "session.end", + params: EmptyParams +) + + +session.Subscribe = ( + method: "session.subscribe", + params: session.SubscribeParameters +) + +session.Unsubscribe = ( + method: "session.unsubscribe", + params: session.UnsubscribeParameters, +) + +session.UnsubscribeParameters = session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + +BrowserCommand = ( + browser.Close // + browser.CreateUserContext // + browser.GetClientWindows // + browser.GetUserContexts // + browser.RemoveUserContext // + browser.SetClientWindowState // + browser.SetDownloadBehavior +) + +browser.ClientWindow = text; + +browser.ClientWindowInfo = { + active: bool, + clientWindow: browser.ClientWindow, + height: js-uint, + state: "fullscreen" / "maximized" / "minimized" / "normal", + width: js-uint, + x: js-int, + y: js-int, +} + +browser.UserContext = text; + +browser.UserContextInfo = { + userContext: browser.UserContext +} + +browser.Close = ( + method: "browser.close", + params: EmptyParams, +) + +browser.CreateUserContext = ( + method: "browser.createUserContext", + params: browser.CreateUserContextParameters, +) + +browser.CreateUserContextParameters = { + ? acceptInsecureCerts: bool, + ? proxy: session.ProxyConfiguration, + ? unhandledPromptBehavior: session.UserPromptHandler +} + +browser.GetClientWindows = ( + method: "browser.getClientWindows", + params: EmptyParams, +) + +browser.GetUserContexts = ( + method: "browser.getUserContexts", + params: EmptyParams, +) + +browser.RemoveUserContext = ( + method: "browser.removeUserContext", + params: browser.RemoveUserContextParameters +) + +browser.RemoveUserContextParameters = { + userContext: browser.UserContext +} + +browser.SetClientWindowState = ( + method: "browser.setClientWindowState", + params: browser.SetClientWindowStateParameters +) + +browser.SetClientWindowStateParameters = { + clientWindow: browser.ClientWindow, + (browser.ClientWindowNamedState // browser.ClientWindowRectState) +} + +browser.ClientWindowNamedState = ( + state: "fullscreen" / "maximized" / "minimized" +) + +browser.ClientWindowRectState = ( + state: "normal", + ? width: js-uint, + ? height: js-uint, + ? x: js-int, + ? y: js-int, +) + +browser.SetDownloadBehavior = ( + method: "browser.setDownloadBehavior", + params: browser.SetDownloadBehaviorParameters +) + +browser.SetDownloadBehaviorParameters = { + downloadBehavior: browser.DownloadBehavior / null, + ? userContexts: [+browser.UserContext] +} + +browser.DownloadBehavior = { + ( + browser.DownloadBehaviorAllowed // + browser.DownloadBehaviorDenied + ) +} + +browser.DownloadBehaviorAllowed = ( + type: "allowed", + destinationFolder: text +) + +browser.DownloadBehaviorDenied = ( + type: "denied" +) + +BrowsingContextCommand = ( + browsingContext.Activate // + browsingContext.CaptureScreenshot // + browsingContext.Close // + browsingContext.Create // + browsingContext.GetTree // + browsingContext.HandleUserPrompt // + browsingContext.LocateNodes // + browsingContext.Navigate // + browsingContext.Print // + browsingContext.Reload // + browsingContext.SetViewport // + browsingContext.TraverseHistory +) + +browsingContext.BrowsingContext = text; + +browsingContext.Locator = ( + browsingContext.AccessibilityLocator / + browsingContext.CssLocator / + browsingContext.ContextLocator / + browsingContext.InnerTextLocator / + browsingContext.XPathLocator +) + +browsingContext.AccessibilityLocator = { + type: "accessibility", + value: { + ? name: text, + ? role: text, + } +} + +browsingContext.CssLocator = { + type: "css", + value: text +} + +browsingContext.ContextLocator = { + type: "context", + value: { + context: browsingContext.BrowsingContext, + } +} + +browsingContext.InnerTextLocator = { + type: "innerText", + value: text, + ? ignoreCase: bool + ? matchType: "full" / "partial", + ? maxDepth: js-uint, +} + +browsingContext.XPathLocator = { + type: "xpath", + value: text +} + +browsingContext.Navigation = text; + +browsingContext.ReadinessState = "none" / "interactive" / "complete" + +browsingContext.UserPromptType = "alert" / "beforeunload" / "confirm" / "prompt"; + +browsingContext.Activate = ( + method: "browsingContext.activate", + params: browsingContext.ActivateParameters +) + +browsingContext.ActivateParameters = { + context: browsingContext.BrowsingContext +} + +browsingContext.CaptureScreenshot = ( + method: "browsingContext.captureScreenshot", + params: browsingContext.CaptureScreenshotParameters +) + +browsingContext.CaptureScreenshotParameters = { + context: browsingContext.BrowsingContext, + ? origin: ("viewport" / "document") .default "viewport", + ? format: browsingContext.ImageFormat, + ? clip: browsingContext.ClipRectangle, +} + +browsingContext.ImageFormat = { + type: text, + ? quality: 0.0..1.0, +} + +browsingContext.ClipRectangle = ( + browsingContext.BoxClipRectangle / + browsingContext.ElementClipRectangle +) + +browsingContext.ElementClipRectangle = { + type: "element", + element: script.SharedReference +} + +browsingContext.BoxClipRectangle = { + type: "box", + x: float, + y: float, + width: float, + height: float +} + +browsingContext.Close = ( + method: "browsingContext.close", + params: browsingContext.CloseParameters +) + +browsingContext.CloseParameters = { + context: browsingContext.BrowsingContext, + ? promptUnload: bool .default false +} + +browsingContext.Create = ( + method: "browsingContext.create", + params: browsingContext.CreateParameters +) + +browsingContext.CreateType = "tab" / "window" + +browsingContext.CreateParameters = { + type: browsingContext.CreateType, + ? referenceContext: browsingContext.BrowsingContext, + ? background: bool .default false, + ? userContext: browser.UserContext +} + +browsingContext.GetTree = ( + method: "browsingContext.getTree", + params: browsingContext.GetTreeParameters +) + +browsingContext.GetTreeParameters = { + ? maxDepth: js-uint, + ? root: browsingContext.BrowsingContext, +} + +browsingContext.HandleUserPrompt = ( + method: "browsingContext.handleUserPrompt", + params: browsingContext.HandleUserPromptParameters +) + +browsingContext.HandleUserPromptParameters = { + context: browsingContext.BrowsingContext, + ? accept: bool, + ? userText: text, +} + +browsingContext.LocateNodes = ( + method: "browsingContext.locateNodes", + params: browsingContext.LocateNodesParameters +) + +browsingContext.LocateNodesParameters = { + context: browsingContext.BrowsingContext, + locator: browsingContext.Locator, + ? maxNodeCount: (js-uint .ge 1), + ? serializationOptions: script.SerializationOptions, + ? startNodes: [ + script.SharedReference ] +} + +browsingContext.Navigate = ( + method: "browsingContext.navigate", + params: browsingContext.NavigateParameters +) + +browsingContext.NavigateParameters = { + context: browsingContext.BrowsingContext, + url: text, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.Print = ( + method: "browsingContext.print", + params: browsingContext.PrintParameters +) + +browsingContext.PrintParameters = { + context: browsingContext.BrowsingContext, + ? background: bool .default false, + ? margin: browsingContext.PrintMarginParameters, + ? orientation: ("portrait" / "landscape") .default "portrait", + ? page: browsingContext.PrintPageParameters, + ? pageRanges: [*(js-uint / text)], + ? scale: (0.1..2.0) .default 1.0, + ? shrinkToFit: bool .default true, +} + +browsingContext.PrintMarginParameters = { + ? bottom: (float .ge 0.0) .default 1.0, + ? left: (float .ge 0.0) .default 1.0, + ? right: (float .ge 0.0) .default 1.0, + ? top: (float .ge 0.0) .default 1.0, +} + +; Minimum size is 1pt x 1pt. Conversion follows from +; https://www.w3.org/TR/css3-values/#absolute-lengths +browsingContext.PrintPageParameters = { + ? height: (float .ge 0.0352) .default 27.94, + ? width: (float .ge 0.0352) .default 21.59, +} + +browsingContext.Reload = ( + method: "browsingContext.reload", + params: browsingContext.ReloadParameters +) + +browsingContext.ReloadParameters = { + context: browsingContext.BrowsingContext, + ? ignoreCache: bool, + ? wait: browsingContext.ReadinessState, +} + +browsingContext.SetViewport = ( + method: "browsingContext.setViewport", + params: browsingContext.SetViewportParameters +) + +browsingContext.SetViewportParameters = { + ? context: browsingContext.BrowsingContext, + ? viewport: browsingContext.Viewport / null, + ? devicePixelRatio: (float .gt 0.0) / null, + ? userContexts: [+browser.UserContext], +} + +browsingContext.Viewport = { + width: js-uint, + height: js-uint, +} + +browsingContext.TraverseHistory = ( + method: "browsingContext.traverseHistory", + params: browsingContext.TraverseHistoryParameters +) + +browsingContext.TraverseHistoryParameters = { + context: browsingContext.BrowsingContext, + delta: js-int, +} + +EmulationCommand = ( + emulation.SetForcedColorsModeThemeOverride // + emulation.SetGeolocationOverride // + emulation.SetLocaleOverride // + emulation.SetNetworkConditions // + emulation.SetScreenOrientationOverride // + emulation.SetScreenSettingsOverride // + emulation.SetScriptingEnabled // + emulation.SetScrollbarTypeOverride // + emulation.SetTimezoneOverride // + emulation.SetTouchOverride // + emulation.SetUserAgentOverride // + emulation.SetViewportMetaOverride +) + + +emulation.SetForcedColorsModeThemeOverride = ( + method: "emulation.setForcedColorsModeThemeOverride", + params: emulation.SetForcedColorsModeThemeOverrideParameters +) + +emulation.SetForcedColorsModeThemeOverrideParameters = { + theme: emulation.ForcedColorsModeTheme / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.ForcedColorsModeTheme = "light" / "dark" + +emulation.SetGeolocationOverride = ( + method: "emulation.setGeolocationOverride", + params: emulation.SetGeolocationOverrideParameters +) + +emulation.SetGeolocationOverrideParameters = { + ( + (coordinates: emulation.GeolocationCoordinates / null) // + (error: emulation.GeolocationPositionError) + ), + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.GeolocationCoordinates = { + latitude: -90.0..90.0, + longitude: -180.0..180.0, + ? accuracy: (float .ge 0.0) .default 1.0, + ? altitude: float / null .default null, + ? altitudeAccuracy: (float .ge 0.0) / null .default null, + ? heading: (0.0...360.0) / null .default null, + ? speed: (float .ge 0.0) / null .default null, +} + +emulation.GeolocationPositionError = { + type: "positionUnavailable" +} + +emulation.SetLocaleOverride = ( + method: "emulation.setLocaleOverride", + params: emulation.SetLocaleOverrideParameters +) + +emulation.SetLocaleOverrideParameters = { + locale: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetNetworkConditions = ( + method: "emulation.setNetworkConditions", + params: emulation.setNetworkConditionsParameters +) + +emulation.setNetworkConditionsParameters = { + networkConditions: emulation.NetworkConditions / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.NetworkConditions = emulation.NetworkConditionsOffline + +emulation.NetworkConditionsOffline = { + type: "offline" +} + +emulation.SetScreenSettingsOverride = ( + method: "emulation.setScreenSettingsOverride", + params: emulation.SetScreenSettingsOverrideParameters +) + +emulation.ScreenArea = { + width: js-uint, + height: js-uint +} + +emulation.SetScreenSettingsOverrideParameters = { + screenArea: emulation.ScreenArea / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScreenOrientationOverride = ( + method: "emulation.setScreenOrientationOverride", + params: emulation.SetScreenOrientationOverrideParameters +) + +emulation.ScreenOrientationNatural = "portrait" / "landscape" +emulation.ScreenOrientationType = "portrait-primary" / "portrait-secondary" / "landscape-primary" / "landscape-secondary" + +emulation.ScreenOrientation = { + natural: emulation.ScreenOrientationNatural, + type: emulation.ScreenOrientationType +} + +emulation.SetScreenOrientationOverrideParameters = { + screenOrientation: emulation.ScreenOrientation / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetUserAgentOverride = ( + method: "emulation.setUserAgentOverride", + params: emulation.SetUserAgentOverrideParameters +) + +emulation.SetUserAgentOverrideParameters = { + userAgent: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetViewportMetaOverride = ( + method: "emulation.setViewportMetaOverride", + params: emulation.SetViewportMetaOverrideParameters +) + +emulation.SetViewportMetaOverrideParameters = { + viewportMeta: true / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScriptingEnabled = ( + method: "emulation.setScriptingEnabled", + params: emulation.SetScriptingEnabledParameters +) + +emulation.SetScriptingEnabledParameters = { + enabled: false / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetScrollbarTypeOverride = ( + method: "emulation.setScrollbarTypeOverride", + params: emulation.SetScrollbarTypeOverrideParameters +) + +emulation.SetScrollbarTypeOverrideParameters = { + scrollbarType: "classic" / "overlay" / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTimezoneOverride = ( + method: "emulation.setTimezoneOverride", + params: emulation.SetTimezoneOverrideParameters +) + +emulation.SetTimezoneOverrideParameters = { + timezone: text / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +emulation.SetTouchOverride = ( + method: "emulation.setTouchOverride", + params: emulation.SetTouchOverrideParameters +) + +emulation.SetTouchOverrideParameters = { + maxTouchPoints: (js-uint .ge 1) / null, + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + + +NetworkCommand = ( + network.AddDataCollector // + network.AddIntercept // + network.ContinueRequest // + network.ContinueResponse // + network.ContinueWithAuth // + network.DisownData // + network.FailRequest // + network.GetData // + network.ProvideResponse // + network.RemoveDataCollector // + network.RemoveIntercept // + network.SetCacheBehavior // + network.SetExtraHeaders +) + + +network.AuthCredentials = { + type: "password", + username: text, + password: text, +} + +network.BytesValue = network.StringValue / network.Base64Value; + +network.StringValue = { + type: "string", + value: text, +} + +network.Base64Value = { + type: "base64", + value: text, +} + +network.Collector = text + +network.CollectorType = "blob" + + +network.SameSite = "strict" / "lax" / "none" / "default" + + +network.Cookie = { + name: text, + value: network.BytesValue, + domain: text, + path: text, + size: js-uint, + httpOnly: bool, + secure: bool, + sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +network.CookieHeader = { + name: text, + value: network.BytesValue, +} + +network.DataType = "request" / "response" + +network.Header = { + name: text, + value: network.BytesValue, +} + +network.Intercept = text + +network.Request = text; + + +network.SetCookieHeader = { + name: text, + value: network.BytesValue, + ? domain: text, + ? httpOnly: bool, + ? expiry: text, + ? maxAge: js-int, + ? path: text, + ? sameSite: network.SameSite, + ? secure: bool, +} + +network.UrlPattern = ( + network.UrlPatternPattern / + network.UrlPatternString +) + +network.UrlPatternPattern = { + type: "pattern", + ?protocol: text, + ?hostname: text, + ?port: text, + ?pathname: text, + ?search: text, +} + + +network.UrlPatternString = { + type: "string", + pattern: text, +} + + +network.AddDataCollector = ( + method: "network.addDataCollector", + params: network.AddDataCollectorParameters +) + +network.AddDataCollectorParameters = { + dataTypes: [+network.DataType], + maxEncodedDataSize: js-uint, + ? collectorType: network.CollectorType .default "blob", + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], +} + +network.AddIntercept = ( + method: "network.addIntercept", + params: network.AddInterceptParameters +) + +network.AddInterceptParameters = { + phases: [+network.InterceptPhase], + ? contexts: [+browsingContext.BrowsingContext], + ? urlPatterns: [*network.UrlPattern], +} + +network.InterceptPhase = "beforeRequestSent" / "responseStarted" / + "authRequired" + +network.ContinueRequest = ( + method: "network.continueRequest", + params: network.ContinueRequestParameters +) + +network.ContinueRequestParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.CookieHeader], + ?headers: [*network.Header], + ?method: text, + ?url: text, +} + +network.ContinueResponse = ( + method: "network.continueResponse", + params: network.ContinueResponseParameters +) + +network.ContinueResponseParameters = { + request: network.Request, + ?cookies: [*network.SetCookieHeader] + ?credentials: network.AuthCredentials, + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.ContinueWithAuth = ( + method: "network.continueWithAuth", + params: network.ContinueWithAuthParameters +) + +network.ContinueWithAuthParameters = { + request: network.Request, + (network.ContinueWithAuthCredentials // network.ContinueWithAuthNoCredentials) +} + +network.ContinueWithAuthCredentials = ( + action: "provideCredentials", + credentials: network.AuthCredentials +) + +network.ContinueWithAuthNoCredentials = ( + action: "default" / "cancel" +) + +network.DisownData = ( + method: "network.disownData", + params: network.disownDataParameters +) + +network.disownDataParameters = { + dataType: network.DataType, + collector: network.Collector, + request: network.Request, +} + +network.FailRequest = ( + method: "network.failRequest", + params: network.FailRequestParameters +) + +network.FailRequestParameters = { + request: network.Request, +} + +network.GetData = ( + method: "network.getData", + params: network.GetDataParameters +) + +network.GetDataParameters = { + dataType: network.DataType, + ? collector: network.Collector, + ? disown: bool .default false, + request: network.Request, +} + +network.ProvideResponse = ( + method: "network.provideResponse", + params: network.ProvideResponseParameters +) + +network.ProvideResponseParameters = { + request: network.Request, + ?body: network.BytesValue, + ?cookies: [*network.SetCookieHeader], + ?headers: [*network.Header], + ?reasonPhrase: text, + ?statusCode: js-uint, +} + +network.RemoveDataCollector = ( + method: "network.removeDataCollector", + params: network.RemoveDataCollectorParameters +) + +network.RemoveDataCollectorParameters = { + collector: network.Collector +} + +network.RemoveIntercept = ( + method: "network.removeIntercept", + params: network.RemoveInterceptParameters +) + +network.RemoveInterceptParameters = { + intercept: network.Intercept +} + +network.SetCacheBehavior = ( + method: "network.setCacheBehavior", + params: network.SetCacheBehaviorParameters +) + +network.SetCacheBehaviorParameters = { + cacheBehavior: "default" / "bypass", + ? contexts: [+browsingContext.BrowsingContext] +} + +network.SetExtraHeaders = ( + method: "network.setExtraHeaders", + params: network.SetExtraHeadersParameters +) + +network.SetExtraHeadersParameters = { + headers: [*network.Header] + ? contexts: [+browsingContext.BrowsingContext] + ? userContexts: [+browser.UserContext] +} + +ScriptCommand = ( + script.AddPreloadScript // + script.CallFunction // + script.Disown // + script.Evaluate // + script.GetRealms // + script.RemovePreloadScript +) + +script.Channel = text; + +script.ChannelValue = { + type: "channel", + value: script.ChannelProperties, +} + +script.ChannelProperties = { + channel: script.Channel, + ? serializationOptions: script.SerializationOptions, + ? ownership: script.ResultOwnership, +} + +script.EvaluateResult = ( + script.EvaluateResultSuccess / + script.EvaluateResultException +) + +script.EvaluateResultSuccess = { + type: "success", + result: script.RemoteValue, + realm: script.Realm +} + +script.EvaluateResultException = { + type: "exception", + exceptionDetails: script.ExceptionDetails + realm: script.Realm +} + +script.ExceptionDetails = { + columnNumber: js-uint, + exception: script.RemoteValue, + lineNumber: js-uint, + stackTrace: script.StackTrace, + text: text, +} + +script.Handle = text; + +script.InternalId = text; + +script.LocalValue = ( + script.RemoteReference / + script.PrimitiveProtocolValue / + script.ChannelValue / + script.ArrayLocalValue / + { script.DateLocalValue } / + script.MapLocalValue / + script.ObjectLocalValue / + { script.RegExpLocalValue } / + script.SetLocalValue +) + +script.ListLocalValue = [*script.LocalValue]; + +script.ArrayLocalValue = { + type: "array", + value: script.ListLocalValue, +} + +script.DateLocalValue = ( + type: "date", + value: text +) + +script.MappingLocalValue = [*[(script.LocalValue / text), script.LocalValue]]; + +script.MapLocalValue = { + type: "map", + value: script.MappingLocalValue, +} + +script.ObjectLocalValue = { + type: "object", + value: script.MappingLocalValue, +} + +script.RegExpValue = { + pattern: text, + ? flags: text, +} + +script.RegExpLocalValue = ( + type: "regexp", + value: script.RegExpValue, +) + +script.SetLocalValue = { + type: "set", + value: script.ListLocalValue, +} + +script.PreloadScript = text; + +script.Realm = text; + +script.PrimitiveProtocolValue = ( + script.UndefinedValue / + script.NullValue / + script.StringValue / + script.NumberValue / + script.BooleanValue / + script.BigIntValue +) + +script.UndefinedValue = { + type: "undefined", +} + +script.NullValue = { + type: "null", +} + +script.StringValue = { + type: "string", + value: text, +} + +script.SpecialNumber = "NaN" / "-0" / "Infinity" / "-Infinity"; + +script.NumberValue = { + type: "number", + value: number / script.SpecialNumber, +} + +script.BooleanValue = { + type: "boolean", + value: bool, +} + +script.BigIntValue = { + type: "bigint", + value: text, +} + +script.RealmType = "window" / "dedicated-worker" / "shared-worker" / "service-worker" / + "worker" / "paint-worklet" / "audio-worklet" / "worklet" + + + +script.RemoteReference = ( + script.SharedReference / + script.RemoteObjectReference +) + +script.SharedReference = { + sharedId: script.SharedId + + ? handle: script.Handle, + Extensible +} + +script.RemoteObjectReference = { + handle: script.Handle, + + ? sharedId: script.SharedId + Extensible +} + +script.RemoteValue = ( + script.PrimitiveProtocolValue / + script.SymbolRemoteValue / + script.ArrayRemoteValue / + script.ObjectRemoteValue / + script.FunctionRemoteValue / + script.RegExpRemoteValue / + script.DateRemoteValue / + script.MapRemoteValue / + script.SetRemoteValue / + script.WeakMapRemoteValue / + script.WeakSetRemoteValue / + script.GeneratorRemoteValue / + script.ErrorRemoteValue / + script.ProxyRemoteValue / + script.PromiseRemoteValue / + script.TypedArrayRemoteValue / + script.ArrayBufferRemoteValue / + script.NodeListRemoteValue / + script.HTMLCollectionRemoteValue / + script.NodeRemoteValue / + script.WindowProxyRemoteValue +) + +script.ListRemoteValue = [*script.RemoteValue]; + +script.MappingRemoteValue = [*[(script.RemoteValue / text), script.RemoteValue]]; + +script.SymbolRemoteValue = { + type: "symbol", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayRemoteValue = { + type: "array", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.ObjectRemoteValue = { + type: "object", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.FunctionRemoteValue = { + type: "function", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.RegExpRemoteValue = { + script.RegExpLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.DateRemoteValue = { + script.DateLocalValue, + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.MapRemoteValue = { + type: "map", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.MappingRemoteValue, +} + +script.SetRemoteValue = { + type: "set", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue +} + +script.WeakMapRemoteValue = { + type: "weakmap", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.WeakSetRemoteValue = { + type: "weakset", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.GeneratorRemoteValue = { + type: "generator", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ErrorRemoteValue = { + type: "error", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ProxyRemoteValue = { + type: "proxy", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.PromiseRemoteValue = { + type: "promise", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.TypedArrayRemoteValue = { + type: "typedarray", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.ArrayBufferRemoteValue = { + type: "arraybuffer", + ? handle: script.Handle, + ? internalId: script.InternalId, +} + +script.NodeListRemoteValue = { + type: "nodelist", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.HTMLCollectionRemoteValue = { + type: "htmlcollection", + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.ListRemoteValue, +} + +script.NodeRemoteValue = { + type: "node", + ? sharedId: script.SharedId, + ? handle: script.Handle, + ? internalId: script.InternalId, + ? value: script.NodeProperties, +} + +script.NodeProperties = { + nodeType: js-uint, + childNodeCount: js-uint, + ? attributes: {*text => text}, + ? children: [*script.NodeRemoteValue], + ? localName: text, + ? mode: "open" / "closed", + ? namespaceURI: text, + ? nodeValue: text, + ? shadowRoot: script.NodeRemoteValue / null, +} + +script.WindowProxyRemoteValue = { + type: "window", + value: script.WindowProxyProperties, + ? handle: script.Handle, + ? internalId: script.InternalId +} + +script.WindowProxyProperties = { + context: browsingContext.BrowsingContext +} + +script.ResultOwnership = "root" / "none" + +script.SerializationOptions = { + ? maxDomDepth: (js-uint / null) .default 0, + ? maxObjectDepth: (js-uint / null) .default null, + ? includeShadowTree: ("none" / "open" / "all") .default "none", +} + +script.SharedId = text; + +script.StackFrame = { + columnNumber: js-uint, + functionName: text, + lineNumber: js-uint, + url: text, +} + +script.StackTrace = { + callFrames: [*script.StackFrame], +} + +script.RealmTarget = { + realm: script.Realm +} + +script.ContextTarget = { + context: browsingContext.BrowsingContext, + ? sandbox: text +} + +script.Target = ( + script.ContextTarget / + script.RealmTarget +) + +script.AddPreloadScript = ( + method: "script.addPreloadScript", + params: script.AddPreloadScriptParameters +) + +script.AddPreloadScriptParameters = { + functionDeclaration: text, + ? arguments: [*script.ChannelValue], + ? contexts: [+browsingContext.BrowsingContext], + ? userContexts: [+browser.UserContext], + ? sandbox: text +} + +script.Disown = ( + method: "script.disown", + params: script.DisownParameters +) + +script.DisownParameters = { + handles: [*script.Handle] + target: script.Target; +} + +script.CallFunction = ( + method: "script.callFunction", + params: script.CallFunctionParameters +) + +script.CallFunctionParameters = { + functionDeclaration: text, + awaitPromise: bool, + target: script.Target, + ? arguments: [*script.LocalValue], + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? this: script.LocalValue, + ? userActivation: bool .default false, +} + +script.Evaluate = ( + method: "script.evaluate", + params: script.EvaluateParameters +) + +script.EvaluateParameters = { + expression: text, + target: script.Target, + awaitPromise: bool, + ? resultOwnership: script.ResultOwnership, + ? serializationOptions: script.SerializationOptions, + ? userActivation: bool .default false, +} + +script.GetRealms = ( + method: "script.getRealms", + params: script.GetRealmsParameters +) + +script.GetRealmsParameters = { + ? context: browsingContext.BrowsingContext, + ? type: script.RealmType, +} + +script.RemovePreloadScript = ( + method: "script.removePreloadScript", + params: script.RemovePreloadScriptParameters +) + +script.RemovePreloadScriptParameters = { + script: script.PreloadScript +} + +StorageCommand = ( + storage.DeleteCookies // + storage.GetCookies // + storage.SetCookie +) + +storage.PartitionKey = { + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.GetCookies = ( + method: "storage.getCookies", + params: storage.GetCookiesParameters +) + + +storage.CookieFilter = { + ? name: text, + ? value: network.BytesValue, + ? domain: text, + ? path: text, + ? size: js-uint, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.BrowsingContextPartitionDescriptor = { + type: "context", + context: browsingContext.BrowsingContext +} + +storage.StorageKeyPartitionDescriptor = { + type: "storageKey", + ? userContext: text, + ? sourceOrigin: text, + Extensible, +} + +storage.PartitionDescriptor = ( + storage.BrowsingContextPartitionDescriptor / + storage.StorageKeyPartitionDescriptor +) + +storage.GetCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +storage.SetCookie = ( + method: "storage.setCookie", + params: storage.SetCookieParameters, +) + + +storage.PartialCookie = { + name: text, + value: network.BytesValue, + domain: text, + ? path: text, + ? httpOnly: bool, + ? secure: bool, + ? sameSite: network.SameSite, + ? expiry: js-uint, + Extensible, +} + +storage.SetCookieParameters = { + cookie: storage.PartialCookie, + ? partition: storage.PartitionDescriptor, +} + +storage.DeleteCookies = ( + method: "storage.deleteCookies", + params: storage.DeleteCookiesParameters, +) + +storage.DeleteCookiesParameters = { + ? filter: storage.CookieFilter, + ? partition: storage.PartitionDescriptor, +} + +InputCommand = ( + input.PerformActions // + input.ReleaseActions // + input.SetFiles +) + +InputResult = ( + input.PerformActionsResult / + input.ReleaseActionsResult / + input.SetFilesResult +) + +input.ElementOrigin = { + type: "element", + element: script.SharedReference +} + +input.PerformActions = ( + method: "input.performActions", + params: input.PerformActionsParameters +) + +input.PerformActionsParameters = { + context: browsingContext.BrowsingContext, + actions: [*input.SourceActions] +} + +input.SourceActions = ( + input.NoneSourceActions / + input.KeySourceActions / + input.PointerSourceActions / + input.WheelSourceActions +) + +input.NoneSourceActions = { + type: "none", + id: text, + actions: [*input.NoneSourceAction] +} + +input.NoneSourceAction = input.PauseAction + +input.KeySourceActions = { + type: "key", + id: text, + actions: [*input.KeySourceAction] +} + +input.KeySourceAction = ( + input.PauseAction / + input.KeyDownAction / + input.KeyUpAction +) + +input.PointerSourceActions = { + type: "pointer", + id: text, + ? parameters: input.PointerParameters, + actions: [*input.PointerSourceAction] +} + +input.PointerType = "mouse" / "pen" / "touch" + +input.PointerParameters = { + ? pointerType: input.PointerType .default "mouse" +} + +input.PointerSourceAction = ( + input.PauseAction / + input.PointerDownAction / + input.PointerUpAction / + input.PointerMoveAction +) + +input.WheelSourceActions = { + type: "wheel", + id: text, + actions: [*input.WheelSourceAction] +} + +input.WheelSourceAction = ( + input.PauseAction / + input.WheelScrollAction +) + +input.PauseAction = { + type: "pause", + ? duration: js-uint +} + +input.KeyDownAction = { + type: "keyDown", + value: text +} + +input.KeyUpAction = { + type: "keyUp", + value: text +} + +input.PointerUpAction = { + type: "pointerUp", + button: js-uint, +} + +input.PointerDownAction = { + type: "pointerDown", + button: js-uint, + input.PointerCommonProperties +} + +input.PointerMoveAction = { + type: "pointerMove", + x: float, + y: float, + ? duration: js-uint, + ? origin: input.Origin, + input.PointerCommonProperties +} + +input.WheelScrollAction = { + type: "scroll", + x: js-int, + y: js-int, + deltaX: js-int, + deltaY: js-int, + ? duration: js-uint, + ? origin: input.Origin .default "viewport", +} + +input.PointerCommonProperties = ( + ? width: js-uint .default 1, + ? height: js-uint .default 1, + ? pressure: float .default 0.0, + ? tangentialPressure: float .default 0.0, + ? twist: (0..359) .default 0, + ; 0 .. Math.PI / 2 + ? altitudeAngle: (0.0..1.5707963267948966) .default 0.0, + ; 0 .. 2 * Math.PI + ? azimuthAngle: (0.0..6.283185307179586) .default 0.0, +) + +input.Origin = "viewport" / "pointer" / input.ElementOrigin + +input.ReleaseActions = ( + method: "input.releaseActions", + params: input.ReleaseActionsParameters +) + +input.ReleaseActionsParameters = { + context: browsingContext.BrowsingContext, +} + +input.SetFiles = ( + method: "input.setFiles", + params: input.SetFilesParameters +) + +input.SetFilesParameters = { + context: browsingContext.BrowsingContext, + element: script.SharedReference, + files: [*text] +} + +input.FileDialogOpened = ( + method: "input.fileDialogOpened", + params: input.FileDialogInfo +) + +input.FileDialogInfo = { + context: browsingContext.BrowsingContext, + ? element: script.SharedReference, + multiple: bool, +} + +WebExtensionCommand = ( + webExtension.Install // + webExtension.Uninstall +) + +webExtension.Extension = text + +webExtension.Install = ( + method: "webExtension.install", + params: webExtension.InstallParameters +) + +webExtension.InstallParameters = { + extensionData: webExtension.ExtensionData, +} + +webExtension.ExtensionData = ( + webExtension.ExtensionArchivePath / + webExtension.ExtensionBase64Encoded / + webExtension.ExtensionPath +) + +webExtension.ExtensionPath = { + type: "path", + path: text, +} + +webExtension.ExtensionArchivePath = { + type: "archivePath", + path: text, +} + +webExtension.ExtensionBase64Encoded = { + type: "base64", + value: text, +} + +webExtension.Uninstall = ( + method: "webExtension.uninstall", + params: webExtension.UninstallParameters +) + +webExtension.UninstallParameters = { + extension: webExtension.Extension, +} diff --git a/py/AGENTS.md b/py/AGENTS.md index 57a5e819e1a8e..27c1aaac41e9a 100644 --- a/py/AGENTS.md +++ b/py/AGENTS.md @@ -51,24 +51,25 @@ def method(param: str | None) -> int | None: pass # Avoid -from typing import Optional def method(param: Optional[str]) -> Optional[int]: pass ``` ### Python version -Code must work with Python 3.10 or later. Use modern syntax features available in 3.10+. +Code must work with Python 3.10 or later. Use modern syntax features available in 3.10+: -See the **Type hints** section for guidance on preferred type annotation syntax (including unions). +- Use `|` for union types instead of `Union[]` +- Use `X | None` instead of `Optional[X]` -For testing: use `bazel test //py/...` which employs a hermetic Python 3.10+ toolchain (see `/AGENTS.md`). - -For ad-hoc scripts, check your Python version locally before running: +When running tests or code in the terminal, explicitly use `python3.10` or later: ```bash -python --version -# Ensure you have 3.10+; on macOS/Linux use python3.10+ or on Windows py -3.10 +# Use explicitly +python3.10 -c "..." +python3.11 -c "..." + +# Avoid relying on `python3` as it may be 3.9 or earlier ``` ### Documentation diff --git a/py/BUILD.bazel b/py/BUILD.bazel index 4774b00a8f005..b560f297b7537 100644 --- a/py/BUILD.bazel +++ b/py/BUILD.bazel @@ -10,6 +10,7 @@ load("@rules_python//sphinxdocs:sphinx.bzl", "sphinx_build_binary", "sphinx_docs load("//common:defs.bzl", "copy_file") load("//py:defs.bzl", "generate_devtools", "py_test_suite") load("//py/private:browsers.bzl", "BROWSERS") +load("//py/private:generate_bidi.bzl", "generate_bidi") load("//py/private:import.bzl", "py_import") exports_files( @@ -582,6 +583,12 @@ py_binary( deps = [requirement("inflection")], ) +py_binary( + name = "generate_bidi", + srcs = ["generate_bidi.py"], + srcs_version = "PY3", +) + [generate_devtools( name = "create-cdp-srcs-{}".format(devtools_version), browser_protocol = "//common/devtools/chromium/{}:browser_protocol".format(devtools_version), @@ -591,6 +598,17 @@ py_binary( protocol_version = devtools_version, ) for devtools_version in BROWSER_VERSIONS] +# Pilot BiDi code generation from CDDL specification +generate_bidi( + name = "create-bidi-src", + cddl_file = "//common/bidi/spec:all.cddl", + enhancements_manifest = "//py/private:bidi_enhancements_manifest.py", + extra_srcs = ["//py/private:cdp.py"], + generator = ":generate_bidi", + module_name = "selenium/webdriver/common/bidi", + spec_version = "1.0", +) + py_test_suite( name = "unit", size = "small", @@ -770,6 +788,7 @@ BROWSER_TESTS = { ] ] + test_suite( name = "test-remote", tags = ["remote"], diff --git a/py/conftest.py b/py/conftest.py index 7cdb2446e4107..ff60cf06f08ec 100644 --- a/py/conftest.py +++ b/py/conftest.py @@ -118,6 +118,14 @@ def pytest_addoption(parser): metavar="DRIVER", help="Driver to run tests against ({})".format(", ".join(drivers)), ) + parser.addoption( + "--browser", + action="append", + choices=drivers, + dest="drivers", + metavar="BROWSER", + help="Browser to run tests against (alias for --driver)", + ) parser.addoption( "--browser-binary", action="store", diff --git a/py/generate_bidi.py b/py/generate_bidi.py new file mode 100755 index 0000000000000..32d19ec83cec9 --- /dev/null +++ b/py/generate_bidi.py @@ -0,0 +1,1964 @@ +#!/usr/bin/env python3.10 +""" +Generate Python WebDriver BiDi command modules from CDDL specification. + +This generator reads CDDL (Concise Data Definition Language) specification files +and produces Python type definitions and command classes that conform to the +WebDriver BiDi protocol. + +Usage: + python generate_bidi.py + +Example: + python generate_bidi.py local.cddl ./selenium/webdriver/common/bidi 1.0 +""" + +import argparse +import importlib.util +import logging +import re +import sys +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from textwrap import indent as tw_indent +from typing import Any + +__version__ = "1.0.0" + +# Logging setup +log_level = logging.INFO +logging.basicConfig(level=log_level) +logger = logging.getLogger("generate_bidi") + +# File headers +SHARED_HEADER = """# DO NOT EDIT THIS FILE! +# +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules.""" + +MODULE_HEADER = f"""{SHARED_HEADER} +# +# WebDriver BiDi module: {{}} +from __future__ import annotations + +from typing import Any +""" + + +def indent(s: str, n: int) -> str: + """Indent a string by n spaces.""" + return tw_indent(s, n * " ") + + +def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: + """Load enhancement manifest from a Python file. + + Args: + manifest_path: Path to Python file containing ENHANCEMENTS dict + + Returns: + Dictionary with enhancement rules, or empty dict if no manifest provided + """ + if not manifest_path: + return {} + + manifest_file = Path(manifest_path) + if not manifest_file.exists(): + logger.warning(f"Enhancement manifest not found: {manifest_path}") + return {} + + try: + spec = importlib.util.spec_from_file_location( + "bidi_enhancements", manifest_file + ) + if spec is None or spec.loader is None: + logger.warning(f"Could not load manifest: {manifest_path}") + return {} + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + enhancements = getattr(module, "ENHANCEMENTS", {}) + dataclass_methods = getattr(module, "DATACLASS_METHOD_TEMPLATES", {}) + method_docstrings = getattr(module, "DATACLASS_METHOD_DOCSTRINGS", {}) + + logger.info(f"Loaded enhancement manifest from: {manifest_path}") + logger.debug(f"Enhancements for modules: {list(enhancements.keys())}") + + return { + "enhancements": enhancements, + "dataclass_methods": dataclass_methods, + "method_docstrings": method_docstrings, + } + except Exception as e: + logger.error(f"Failed to load enhancement manifest: {e}", exc_info=True) + return {} + + +class CddlType(Enum): + """CDDL type mappings to Python types.""" + + TSTR = "str" # text string + TEXT = "str" # text (alias) + UINT = "int" # unsigned integer + INT = "int" # signed integer + NINT = "int" # negative integer + BOOL = "bool" # boolean + NULL = "None" # null + ANY = "Any" # any type + + @classmethod + def get_annotation(cls, cddl_type: str) -> str: + """Get Python type annotation for a CDDL type.""" + cddl_type = cddl_type.strip().lower() + + # Handle basic types + for member in cls: + if cddl_type == member.name.lower(): + return member.value + + # Handle composite types + if cddl_type.startswith("["): # Array + inner = cddl_type.strip("[]+ ") + inner_type = cls.get_annotation(inner) + return f"list[{inner_type}]" + + if cddl_type.startswith("{"): # Map/Dict + return "dict[str, Any]" + + # Default to Any for unknown types + return "Any" + + +@dataclass +class CddlCommand: + """Represents a CDDL command definition.""" + + module: str + name: str + params: dict[str, str] = field(default_factory=dict) + required_params: set[str] = field(default_factory=set) + result: str | None = None + description: str = "" + + def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: + """Generate Python method code for this command. + + Args: + enhancements: Dictionary with enhancement rules for this method + """ + enhancements = enhancements or {} + method_name = self._camel_to_snake(self.name) + + # Build parameter list with type hints + # Check if there's a params_override for user-friendly named arguments + params_to_use = self.params + if "params_override" in enhancements: + params_to_use = enhancements["params_override"] + + param_strs = [] + param_names = [] # Keep track of parameter names for later use + for param_name, param_type in params_to_use.items(): + if param_type in ["bool", "str", "int"]: + python_type = param_type + else: + python_type = CddlType.get_annotation(param_type) + snake_param = self._camel_to_snake(param_name) + param_names.append((param_name, snake_param)) + param_strs.append(f"{snake_param}: {python_type} | None = None") + + if param_strs: + # Check if full signature would exceed line length limit (120 chars) + single_line_signature = ( + f" def {method_name}(self, {', '.join(param_strs)}):" + ) + if len(single_line_signature) > 120: + # Format parameters on multiple lines + body = f" def {method_name}(\n" + body += " self,\n" + for i, param_str in enumerate(param_strs): + if i < len(param_strs) - 1: + body += f" {param_str},\n" + else: + body += f" {param_str},\n" + body += " ):\n" + else: + param_list = "self, " + ", ".join(param_strs) + body = f" def {method_name}({param_list}):\n" + else: + body = f" def {method_name}(self):\n" + body += f' """{self.description or "Execute " + self.module + "." + self.name}."""\n' + + # Add automatic validation for required parameters + # (This is used unless there's no required_params, in which case all params are optional) + if self.required_params: + method_snake = self._camel_to_snake(self.name) + for param_name, snake_param in param_names: + if param_name in self.required_params: + body += f" if {snake_param} is None:\n" + msg = f"{method_snake}() missing required argument:" + body += ( + f' raise TypeError("{msg} {{{{snake_param!r}}}}")\n' + ) + body += "\n" + + # Add validation if specified in enhancements (for additional business logic validation) + if "validate" in enhancements: + validate_func = enhancements["validate"] + # Build parameter list for validation function + param_args = ", ".join(f"{snake}={snake}" for _, snake in param_names) + body += f" {validate_func}({param_args})\n" + body += "\n" + + # Add transformation and preprocessing + # First, check if any transform is needed + if "transform" in enhancements: + transform_spec = enhancements["transform"] + if isinstance(transform_spec, dict): + # New format with explicit function and result parameter + transform_func = transform_spec.get("func") + result_param = transform_spec.get("result_param", "params") + input_params = [ + transform_spec.get(k) + for k in ["allowed", "destination_folder"] + if transform_spec.get(k) + ] + + if transform_func and result_param: + body += f" {result_param} = None\n" + param_args = ", ".join(input_params) + body += f" {result_param} = {transform_func}({param_args})\n" + body += "\n" + else: + # Legacy format for backward compatibility + transform_func = transform_spec + if self.name == "setDownloadBehavior": + body += " download_behavior = None\n" + body += f" download_behavior = {transform_func}(allowed, destination_folder)\n" + body += "\n" + + # Add preprocessing for serialization (check for to_bidi_dict method) + if "preprocess" in enhancements: + preprocess_rules = enhancements["preprocess"] + for param_name, preprocess_type in preprocess_rules.items(): + snake_param = self._camel_to_snake(param_name) + if preprocess_type == "check_serialize_method": + body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n" + body += ( + f" {snake_param} = {snake_param}.to_bidi_dict()\n" + ) + body += "\n" + + # Build params dict + body += " params = {\n" + + # If there's a transform with a result parameter, map it to the BiDi protocol name + if "transform" in enhancements and isinstance(enhancements["transform"], dict): + transform_spec = enhancements["transform"] + result_param = transform_spec.get("result_param") + + # Map the result parameter to the original CDDL parameter name + if result_param == "download_behavior": + body += ' "downloadBehavior": download_behavior,\n' + # Add remaining parameters that weren't part of the transform + for cddl_param_name in self.params: + if cddl_param_name not in ["downloadBehavior"]: + snake_name = self._camel_to_snake(cddl_param_name) + body += f' "{cddl_param_name}": {snake_name},\n' + else: + # Standard parameter mapping from CDDL + for param_name, snake_param in param_names: + body += f' "{param_name}": {snake_param},\n' + + body += " }\n" + body += " params = {k: v for k, v in params.items() if v is not None}\n" + body += f' cmd = command_builder("{self.module}.{self.name}", params)\n' + body += " result = self._conn.execute(cmd)\n" + + # Add response handling for extraction/deserialization + if "extract_field" in enhancements: + extract_field = enhancements["extract_field"] + extract_property = enhancements.get("extract_property") + + # Check if we also need to deserialize the extracted field + deserialize_rules = enhancements.get("deserialize", {}) + + if extract_property: + # Extract property from list items + body += f' if result and "{extract_field}" in result:\n' + body += f' items = result.get("{extract_field}", [])\n' + body += " return [\n" + body += f' item.get("{extract_property}")\n' + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" + elif extract_field in deserialize_rules: + # Extract field and deserialize to typed objects + type_name = deserialize_rules[extract_field] + body += f' if result and "{extract_field}" in result:\n' + body += f' items = result.get("{extract_field}", [])\n' + body += " return [\n" + body += f" {type_name}(\n" + body += self._generate_field_args(extract_field, type_name) + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" + else: + # Simple field extraction (return the value directly, not wrapped in result dict) + body += f' if result and "{extract_field}" in result:\n' + body += f' extracted = result.get("{extract_field}")\n' + body += " return extracted\n" + body += " return result\n" + elif "deserialize" in enhancements: + # Deserialize response to typed objects (legacy, without extract_field) + deserialize_rules = enhancements["deserialize"] + for response_field, type_name in deserialize_rules.items(): + body += f' if result and "{response_field}" in result:\n' + body += f' items = result.get("{response_field}", [])\n' + body += " return [\n" + body += f" {type_name}(\n" + body += self._generate_field_args(response_field, type_name) + body += " )\n" + body += " for item in items\n" + body += " if isinstance(item, dict)\n" + body += " ]\n" + body += " return []\n" + else: + # No special response handling, just return the result + body += " return result\n" + + return body + + def _generate_field_args(self, response_field: str, type_name: str) -> str: + """Generate constructor arguments for deserializing response objects. + + For now, this handles ClientWindowInfo and Info specifically. + Could be extended to be more generic. + """ + if type_name == "ClientWindowInfo": + return ( + ' active=item.get("active"),\n' + ' client_window=item.get("clientWindow"),\n' + ' height=item.get("height"),\n' + ' state=item.get("state"),\n' + ' width=item.get("width"),\n' + ' x=item.get("x"),\n' + ' y=item.get("y")\n' + ) + elif type_name == "Info": + return ( + ' children=_deserialize_info_list(item.get("children", [])),\n' + ' client_window=item.get("clientWindow"),\n' + ' context=item.get("context"),\n' + ' original_opener=item.get("originalOpener"),\n' + ' url=item.get("url"),\n' + ' user_context=item.get("userContext"),\n' + ' parent=item.get("parent")\n' + ) + # For other types, return empty + return "" + + @staticmethod + def _camel_to_snake(name: str) -> str: + """Convert camelCase to snake_case.""" + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + +@dataclass +class CddlTypeDefinition: + """Represents a CDDL type definition.""" + + module: str + name: str + fields: dict[str, str] = field(default_factory=dict) + description: str = "" + + def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: + """Generate Python dataclass code for this type. + + Args: + enhancements: Dictionary containing dataclass_methods and method_docstrings + """ + enhancements = enhancements or {} + dataclass_methods = enhancements.get("dataclass_methods", {}) + method_docstrings = enhancements.get("method_docstrings", {}) + + # Generate class name from type name (keep it as-is, don't split on underscores) + class_name = self.name + code = "@dataclass\n" + code += f"class {class_name}:\n" + code += f' """{self.description or self.name}."""\n\n' + + if not self.fields: + code += " pass\n" + else: + for field_name, field_type in self.fields.items(): + # Convert CDDL type to Python type + python_type = self._get_python_type(field_type) + snake_name = CddlCommand._camel_to_snake(field_name) + + # Check if the CDDL field type is a quoted string literal (e.g., type: "key") + # These are discriminant fields: auto-populate and exclude from __init__ + # so callers don't need to pass them as positional or keyword arguments. + literal_match = re.match(r'^"([^"]+)"$', field_type.strip()) + if literal_match: + literal_value = literal_match.group(1) + code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' + # Check if this field is a list type (using lowercase 'list[' from Python 3.10+ syntax) + elif python_type.startswith("list["): + # Remove the trailing ' | None' from list types since default_factory=list ensures non-None + type_annotation = python_type.replace(" | None", "") + code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n" + # Check if this field is a dict type (using lowercase 'dict[' from Python 3.10+ syntax) + elif python_type.startswith("dict["): + # Remove the trailing ' | None' from dict types since default_factory=dict ensures non-None + type_annotation = python_type.replace(" | None", "") + code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n" + else: + code += f" {snake_name}: {python_type} = None\n" + + # Add custom methods if defined for this class + if class_name in dataclass_methods: + code += "\n" + methods_dict = dataclass_methods[class_name] + docstrings_dict = method_docstrings.get(class_name, {}) + + for method_name in methods_dict: + method_impl = methods_dict[method_name] + docstring = docstrings_dict.get(method_name, "") + code += f" def {method_name}(self):\n" + if docstring: + code += f' """{docstring}"""\n' + code += f" {method_impl}\n" + code += "\n" + + return code + + @staticmethod + def _get_python_type(cddl_type: str) -> str: + """Convert CDDL type to Python type annotation using Python 3.10+ syntax.""" + cddl_type = cddl_type.strip().lower() + + # Handle basic types + type_mapping = { + "tstr": "str", + "text": "str", + "uint": "int", + "int": "int", + "nint": "int", + "bool": "bool", + "null": "None", + } + + for cddl, python in type_mapping.items(): + if cddl_type == cddl: + # Use Python 3.10+ union syntax: type | None + return f"{python} | None" + + # Handle arrays + if cddl_type.startswith("["): + inner = cddl_type.strip("[]+ ") + inner_type = CddlTypeDefinition._get_python_type(inner) + # Remove " | None" from inner type since it might be wrapped + if " | None" in inner_type: + inner_base = inner_type.replace(" | None", "") + return f"list[{inner_base} | None] | None" + return f"list[{inner_type}] | None" + + # Handle maps/dicts + if cddl_type.startswith("{"): + return "dict[str, Any] | None" + + # Default to Any for unknown/complex types + return "Any | None" + + +@dataclass +class CddlEnum: + """Represents a CDDL enum definition (string union).""" + + module: str + name: str + values: list[str] = field(default_factory=list) + description: str = "" + + def to_python_class(self) -> str: + """Generate Python enum class code. + + Generates a simple class with string constants to match the existing + pattern in the codebase (e.g., ClientWindowState). + """ + class_name = self.name + code = f"class {class_name}:\n" + code += f' """{self.description or self.name}."""\n\n' + + for value in self.values: + # Convert value to UPPER_SNAKE_CASE constant name + const_name = self._value_to_const_name(value) + code += f' {const_name} = "{value}"\n' + + return code + + @staticmethod + def _value_to_const_name(value: str) -> str: + """Convert enum string value to constant name. + + Examples: + "none" -> "NONE" + "portrait-primary" -> "PORTRAIT_PRIMARY" + "interactive" -> "INTERACTIVE" + """ + # Replace hyphens with underscores + const_name = value.replace("-", "_") + # Convert to uppercase + return const_name.upper() + + +@dataclass +class CddlEvent: + """Represents a CDDL event definition (incoming message from browser).""" + + module: str + name: str + method: str + params_type: str + description: str = "" + + def to_python_dataclass(self) -> str: + """Generate Python dataclass code for the event info type. + + Returns a dataclass code that attempts to use globals().get() for safety. + """ + class_name = self.name + + # Extract the type name from params_type (e.g., "browsingContext.Info" -> "Info") + # The params_type comes from the CDDL and includes module prefix + type_name = ( + self.params_type.split(".")[-1] + if "." in self.params_type + else self.params_type + ) + + # Special case: if the type is BaseNavigationInfo, use BaseNavigationInfo directly + # (NavigationInfo will be created as an alias to it) + if type_name == "NavigationInfo": + type_name = "BaseNavigationInfo" + + # Generate type alias using globals().get() for safety + code = f"# Event: {self.method}\n" + code += f"{class_name} = globals().get('{type_name}', dict) # Fallback to dict if type not defined\n" + + return code + + +@dataclass +class CddlModule: + """Represents a CDDL module (e.g., script, network, browsing_context).""" + + name: str + commands: list[CddlCommand] = field(default_factory=list) + types: list[CddlTypeDefinition] = field(default_factory=list) + enums: list[CddlEnum] = field(default_factory=list) + events: list[CddlEvent] = field(default_factory=list) + + @staticmethod + def _convert_method_to_event_name(method_suffix: str) -> str: + """Convert BiDi method suffix to friendly event name. + + Examples: + "contextCreated" -> "context_created" + "navigationStarted" -> "navigation_started" + "userPromptOpened" -> "user_prompt_opened" + """ + # Convert camelCase to snake_case + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: + """Generate Python code for this module. + + Args: + enhancements: Dictionary with module-level enhancements + """ + enhancements = enhancements or {} + code = MODULE_HEADER.format(self.name) + + # Collect needed imports to avoid duplicates + needs_command_builder = bool(self.commands) + needs_dataclass = self.commands or self.types or self.events + needs_threading = self.events + needs_callable = self.events + needs_session = self.events + + # Add imports (field import will be added conditionally after code generation) + if needs_command_builder: + code += "from .common import command_builder\n" + if needs_dataclass: + code += "from dataclasses import dataclass\n" + if needs_threading: + code += "import threading\n" + if needs_callable: + code += "from collections.abc import Callable\n" + if needs_session: + code += "from selenium.webdriver.common.bidi.session import Session\n" + + code += "\n\n" + + # Add helper function definitions from enhancements + # Collect all referenced helper functions (validate, transform) + helper_funcs_to_add = set() + for cmd in self.commands: + method_name_snake = cmd._camel_to_snake(cmd.name) + method_enhancements = enhancements.get(method_name_snake, {}) + if "validate" in method_enhancements: + helper_funcs_to_add.add(("validate", method_enhancements["validate"])) + if "transform" in method_enhancements and isinstance( + method_enhancements["transform"], dict + ): + transform_spec = method_enhancements["transform"] + if "func" in transform_spec: + helper_funcs_to_add.add(("transform", transform_spec["func"])) + + # Generate helper functions if needed + if helper_funcs_to_add: + for func_type, func_name in sorted(helper_funcs_to_add): + if ( + func_type == "validate" + and func_name == "validate_download_behavior" + ): + code += """def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + \"\"\"Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts + + Raises: + ValueError: If parameters are invalid + \"\"\" + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +""" + elif ( + func_type == "transform" + and func_name == "transform_download_params" + ): + code += """def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any] | None: + \"\"\"Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads (accepts str or + pathlib.Path; will be coerced to str) + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + \"\"\" + if allowed is True: + return { + "type": "allowed", + # Coerce pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": str(destination_folder) if destination_folder is not None else None, + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +""" + + # Generate enums first (excluding those in exclude_types) + exclude_types = set(enhancements.get("exclude_types", [])) + + # Also exclude any types that have extra_dataclasses overrides + # Extract class names from extra_dataclasses strings + for extra_cls in enhancements.get("extra_dataclasses", []): + # Match "class ClassName:" pattern + match = re.search(r"class\s+(\w+)\s*:", extra_cls) + if match: + exclude_types.add(match.group(1)) + + for enum_def in self.enums: + if enum_def.name in exclude_types: + continue + code += enum_def.to_python_class() + code += "\n\n" + + # Emit module-level aliases from enhancement manifest (e.g. LogLevel = Level) + for alias, target in enhancements.get("aliases", {}).items(): + code += f"{alias} = {target}\n\n" + + # Generate type dataclasses, skipping any overridden by extra_dataclasses + for type_def in self.types: + if type_def.name in exclude_types: + continue + code += type_def.to_python_dataclass(enhancements) + code += "\n\n" + + # Emit extra dataclasses from enhancement manifest (non-CDDL additions) + for extra_cls in enhancements.get("extra_dataclasses", []): + code += extra_cls + code += "\n\n" + + # Emit extra type aliases from enhancement manifest (e.g., union types for events) + for extra_alias in enhancements.get("extra_type_aliases", []): + code += extra_alias + code += "\n\n" + + # NOTE: Don't generate event type aliases here - they reference types that may not be defined yet + # They will be generated after the class definition instead + + # Generate EVENT_NAME_MAPPING for the module (before the module class) + if self.events: + # Generate EVENT_NAME_MAPPING for the module + code += "# BiDi Event Name to Parameter Type Mapping\n" + code += "EVENT_NAME_MAPPING = {\n" + for event_def in self.events: + # Convert method name to user-friendly event name + # e.g., "browsingContext.contextCreated" -> "context_created" + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + code += f' "{event_name}": "{event_def.method}",\n' + # Extra events not in the CDDL spec (e.g. Chromium-specific events) + for extra_evt in enhancements.get("extra_events", []): + code += ( + f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' + ) + code += "}\n\n" + + # Add custom method function definitions before the class (for browsingContext) + if self.name == "browsingContext": + # Add helper function for recursive Info deserialization + code += """def _deserialize_info_list(items: list) -> list | None: + \"\"\"Recursively deserialize a list of dicts to Info objects. + + Args: + items: List of dicts from the API response + + Returns: + List of Info objects with properly nested children, or None if empty + \"\"\" + if not items or not isinstance(items, list): + return None + + result = [] + for item in items: + if isinstance(item, dict): + # Recursively deserialize children only if the key exists in response + children_list = None + if "children" in item: + children_list = _deserialize_info_list(item.get("children", [])) + info = Info( + children=children_list, + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent"), + ) + result.append(info) + return result if result else None + + +""" + code += "\n\n" + + # Generate EventConfig and _EventManager for modules with events + if self.events: + # Generate EventConfig dataclass + code += """@dataclass +class EventConfig: + \"\"\"Configuration for a BiDi event.\"\"\" + event_key: str + bidi_event: str + event_class: type + + +""" + + # Generate _EventManager class + code += """class _EventWrapper: + \"\"\"Wrapper to provide event_class attribute for WebSocketConnection callbacks.\"\"\" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + \"\"\"Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + \"\"\" + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, \"from_json\") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend([\"_\", char.lower()]) + else: + result.append(char) + return \"\".join(result) + + +class _EventManager: + \"\"\"Manages event subscriptions and callbacks.\"\"\" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + \"\"\"Subscribe to a BiDi event if not already subscribed.\"\"\" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get(\"subscription\") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + \"callbacks\": [], + \"subscription_id\": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + \"\"\"Unsubscribe from a BiDi event if no more callbacks exist.\"\"\" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry[\"callbacks\"]: + session = Session(self.conn) + sub_id = entry.get(\"subscription_id\") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event][\"callbacks\"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry[\"callbacks\"]: + entry[\"callbacks\"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry[\"callbacks\"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get(\"subscription_id\") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + +""" + code += "\n\n" + + # Generate class + # Convert module name (camelCase or snake_case) to proper class name (PascalCase) + class_name = module_name_to_class_name(self.name) + code += f"class {class_name}:\n" + code += f' """WebDriver BiDi {self.name} module."""\n\n' + + # Add EVENT_CONFIGS dict if there are events + if self.events: + code += ( + " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined + ) + + if self.name == "script": + code += " def __init__(self, conn, driver=None) -> None:\n" + code += " self._conn = conn\n" + code += " self._driver = driver\n" + else: + code += " def __init__(self, conn) -> None:\n" + code += " self._conn = conn\n" + + # Initialize _event_manager if there are events + if self.events: + code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" + + # Append extra init code from enhancements (e.g. self.intercepts = []) + for init_line in enhancements.get("extra_init_code", []): + code += f" {init_line}\n" + + code += "\n" + + # Generate command methods + exclude_methods = enhancements.get("exclude_methods", []) + + # Automatically exclude methods that are defined in extra_methods + # to prevent generating duplicates + if "extra_methods" in enhancements: + for extra_method in enhancements["extra_methods"]: + # Extract method name from "def method_name(" + match = re.search(r"def\s+(\w+)\s*\(", extra_method) + if match: + exclude_methods = list(exclude_methods) + [match.group(1)] + + if self.commands: + for command in self.commands: + # Get method-specific enhancements + # Convert command name to snake_case to match enhancement manifest keys + method_name_snake = command._camel_to_snake(command.name) + if method_name_snake in exclude_methods: + continue + method_enhancements = enhancements.get(method_name_snake, {}) + code += command.to_python_method(method_enhancements) + code += "\n" + elif not self.events and not enhancements.get("extra_methods", []): + code += " pass\n" + + # Emit extra methods from enhancement manifest + for extra_method in enhancements.get("extra_methods", []): + code += extra_method + code += "\n" + + # Add delegating event handler methods if events are present + if self.events: + code += """ + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + \"\"\"Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + \"\"\" + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + \"\"\"Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + \"\"\" + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + \"\"\"Clear all event handlers.\"\"\" + return self._event_manager.clear_event_handlers() +""" + + # Generate event info type aliases AFTER the class definition + # This ensures all types are available when we create the aliases + if self.events: + code += "\n# Event Info Type Aliases\n" + # Check for explicit event_type_aliases in the enhancement manifest + event_type_aliases = enhancements.get("event_type_aliases", {}) + for event_def in self.events: + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # Check if there's an explicit alias defined in the enhancement manifest + if event_name in event_type_aliases: + # Use the alias directly + type_name = event_type_aliases[event_name] + code += f"# Event: {event_def.method}\n" + code += f"{event_def.name} = {type_name}\n" + else: + # Fall back to the original behavior + code += event_def.to_python_dataclass() + code += "\n" + + # Now populate EVENT_CONFIGS after the aliases are defined + code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" + # Use globals() to look up types dynamically to handle missing types gracefully + code += "_globals = globals()\n" + code += f"{class_name}.EVENT_CONFIGS = {{\n" + for event_def in self.events: + # Convert method name to user-friendly event name + method_parts = event_def.method.split(".") + if len(method_parts) == 2: + event_name = self._convert_method_to_event_name(method_parts[1]) + # Try to get event class from globals, default to dict if not found + getter = f'_globals.get("{event_def.name}", dict)' + condition = f'_globals.get("{event_def.name}")' + event_class = f"{getter} if {condition} else dict" + + # Build the entry line and check if it exceeds 120 chars + single_line = ( + f' "{event_name}": ' + f'EventConfig("{event_name}", "{event_def.method}", {event_class}),' + ) + + if len(single_line) > 120: + # Break into multiple lines + code += f' "{event_name}": EventConfig(\n' + code += f' "{event_name}",\n' + code += f' "{event_def.method}",\n' + code += f" {event_class},\n" + code += " ),\n" + else: + code += single_line + "\n" + # Extra events not in the CDDL spec + for extra_evt in enhancements.get("extra_events", []): + ek = extra_evt["event_key"] + be = extra_evt["bidi_event"] + ec = extra_evt["event_class"] + code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' + code += "}\n" + + # Check if field() is actually used in the generated code + # If so, add the field import after the dataclass import + if "field(" in code: + # Find where to insert the field import + # It should go after "from dataclasses import dataclass" line + dataclass_import_pattern = r"from dataclasses import dataclass\n" + if re.search(dataclass_import_pattern, code): + code = re.sub( + dataclass_import_pattern, + "from dataclasses import dataclass\nfrom dataclasses import field\n", + code, + count=1 + ) + elif "from dataclasses import" not in code: + # If there's no dataclasses import yet, add field import after typing + code = code.replace( + "from typing import Any\n", + "from typing import Any\nfrom dataclasses import field\n" + ) + + return code + + +class CddlParser: + """Parse CDDL specification files.""" + + def __init__(self, cddl_path: str): + """Initialize parser with CDDL file path.""" + self.cddl_path = Path(cddl_path) + self.content = "" + self.modules: dict[str, CddlModule] = {} + self.definitions: dict[str, str] = {} + self.event_names: set[str] = set() # Names of definitions that are events + self._read_file() + + def _read_file(self) -> None: + """Read and preprocess CDDL file.""" + if not self.cddl_path.exists(): + raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") + + with open(self.cddl_path, encoding="utf-8") as f: + self.content = f.read() + + logger.info(f"Loaded CDDL file: {self.cddl_path}") + + def parse(self) -> dict[str, CddlModule]: + """Parse CDDL content and return modules.""" + # Remove comments + content = self._remove_comments(self.content) + + # Extract all definitions + self._extract_definitions(content) + + # Extract event names from event union definitions + self._extract_event_names() + + # Extract type definitions by module + self._extract_types() + + # Extract event definitions by module + self._extract_events() + + # Extract command definitions by module + self._extract_commands() + + # If no modules found, create a default one from the filename + if not self.modules: + module_name = self.cddl_path.stem + default_module = CddlModule(name=module_name) + self.modules[module_name] = default_module + logger.warning(f"No modules found in CDDL, creating default: {module_name}") + + return self.modules + + def _remove_comments(self, content: str) -> str: + """Remove comments from CDDL content.""" + # CDDL uses ; for comments to end of line + lines = content.split("\n") + cleaned = [] + for line in lines: + if ";" in line and not line.strip().startswith(";"): + line = line[: line.index(";")] + elif line.strip().startswith(";"): + continue + cleaned.append(line) + return "\n".join(cleaned) + + def _extract_definitions(self, content: str) -> None: + """Extract CDDL definitions (type definitions, commands, etc.).""" + # Match pattern: Name = Definition + # Handles multiline definitions properly + pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" + + for match in re.finditer(pattern, content, re.DOTALL): + name = match.group(1).strip() + definition = match.group(2).strip() + self.definitions[name] = definition + logger.debug(f"Extracted definition: {name}") + + def _extract_event_names(self) -> None: + """Extract event names from event union definitions. + + Event union definitions follow pattern: + module.ModuleEvent = ( + module.EventName1 // + module.EventName2 // + ... + ) + """ + for def_name, def_content in self.definitions.items(): + # Check if this looks like an event union (name ends with "Event") and + # contains a module-qualified reference like "module.EventName". + # Handles both single-item (no //) and multi-item (// separated) unions. + if "Event" in def_name and re.search(r"\w+\.\w+", def_content): + # Extract event names from the union (works for single and multi-item) + event_refs = re.findall(r"(\w+\.\w+)", def_content) + for event_ref in event_refs: + self.event_names.add(event_ref) + logger.debug(f"Identified event: {event_ref} (from {def_name})") + + def _extract_types(self) -> None: + """Extract type definitions from parsed definitions.""" + # Type definitions follow pattern: module.TypeName = { field: type, ... } + # They have dots in the name and curly braces in the content + # But they DON'T have method: "..." pattern (which means it's not a command) + # Enums follow pattern: module.EnumName = "value1" / "value2" / ... + + for def_name, def_content in self.definitions.items(): + # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") + if "." not in def_name: + continue + + # Skip if it's a command (contains method: pattern) + if "method:" in def_content: + continue + + # Extract module.TypeName + if "." in def_name: + module_name, type_name = def_name.rsplit(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Check if this is an enum (string union with /) + if self._is_enum_definition(def_content): + # Extract enum values + values = self._extract_enum_values(def_content) + if values: + enum_def = CddlEnum( + module=module_name, + name=type_name, + values=values, + description=f"{type_name}", + ) + self.modules[module_name].enums.append(enum_def) + logger.debug( + f"Found enum: {def_name} with {len(values)} values" + ) + else: + # Extract fields from type definition + fields = self._extract_type_fields(def_content) + + if fields: # Only create type if it has fields + type_def = CddlTypeDefinition( + module=module_name, + name=type_name, + fields=fields, + description=f"{type_name}", + ) + self.modules[module_name].types.append(type_def) + logger.debug( + f"Found type: {def_name} with {len(fields)} fields" + ) + + def _is_enum_definition(self, definition: str) -> bool: + """Check if a definition is an enum (string union with /). + + Enums are defined as: "value1" / "value2" / "value3" + """ + # Clean whitespace + clean_def = definition.strip() + + # Must not have curly braces (that would be a type definition) + if "{" in clean_def or "}" in clean_def: + return False + + # Must contain the union operator / surrounded by quotes + # Pattern: "something" / "something_else" + return " / " in clean_def and '"' in clean_def + + def _extract_enum_values(self, enum_definition: str) -> list[str]: + """Extract individual values from an enum definition. + + Enums are defined as: "value1" / "value2" / "value3" + Can span multiple lines. + """ + values = [] + + # Clean the definition and extract quoted strings + # Split by / and extract quoted values + parts = enum_definition.split("/") + + for part in parts: + part = part.strip() + + # Extract quoted string - use search instead of match to find quotes anywhere + match = re.search(r'"([^"]*)"', part) + if match: + value = match.group(1) + values.append(value) + logger.debug(f"Extracted enum value: {value}") + + return values + + @staticmethod + def _normalize_cddl_type(field_type: str) -> str: + """Normalize a CDDL type expression to a simple Python-compatible form. + + Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and + replaces interval/constraint expressions with their base types so that + the caller can safely check for nested struct syntax. + + Examples: + '(float .ge 0.0) .default 1.0' -> 'float' + '(float .ge 0.0) / null' -> 'float / null' + '(0.0...360.0) / null' -> 'float / null' + '-90.0..90.0' -> 'float' + 'float / null .default null' -> 'float / null' + """ + result = field_type + # Remove trailing .default annotations + result = re.sub(r"\s*\.default\s+\S+", "", result) + # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType + result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) + # Replace parenthesised numeric interval types: (0.0...360.0) -> float + result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) + # Replace bare numeric interval types: -90.0..90.0 -> float + result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) + return result.strip() + + def _extract_type_fields(self, type_definition: str) -> dict[str, str]: + """Extract fields from a type definition block.""" + fields = {} + + # Remove outer braces + clean_def = type_definition.strip() + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Parse each line for field: type patterns + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line or line.startswith("//"): + continue + + # Match pattern: [?] fieldName: type + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + field_name = match.group(1).strip() + field_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(field_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + fields[field_name] = normalized_type + logger.debug(f"Extracted field {field_name}: {normalized_type}") + + return fields + + def _extract_events(self) -> None: + """Extract event definitions from parsed definitions. + + Events are definitions that: + 1. Are listed in an event union (e.g., BrowsingContextEvent) + 2. Have method: "..." and params: ... fields + + Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) + """ + # Find definitions that are in the event_names set + event_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip if not identified as an event + if def_name not in self.event_names: + continue + + # Extract method and params + match = event_pattern.search(def_content) + if match: + method = match.group(1) # e.g., "browsingContext.contextCreated" + params_type = match.group(2) # e.g., "browsingContext.Info" + + # Extract module name from method + if "." in method: + module_name, _ = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract event name from definition name (e.g., browsingContext.ContextCreated) + _, event_name = def_name.rsplit(".", 1) + + # Create event + event = CddlEvent( + module=module_name, + name=event_name, + method=method, + params_type=params_type, + description=f"Event: {method}", + ) + + self.modules[module_name].events.append(event) + logger.debug( + f"Found event: {def_name} (method={method}, params={params_type})" + ) + + def _extract_commands(self) -> None: + """Extract command definitions from parsed definitions.""" + # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) + command_pattern = re.compile( + r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)" + ) + + for def_name, def_content in self.definitions.items(): + # Skip definitions that are events (they share the same pattern) + if def_name in self.event_names: + continue + matches = list(command_pattern.finditer(def_content)) + if matches: + for match in matches: + method = match.group(1) # e.g., "session.new" + params_type = match.group(2) # e.g., "session.NewParameters" + + # Extract module name from method + if "." in method: + module_name, command_name = method.split(".", 1) + + # Create module if not exists + if module_name not in self.modules: + self.modules[module_name] = CddlModule(name=module_name) + + # Extract parameters and required parameters + params, required_params = self._extract_parameters_and_required( + params_type + ) + + # Create command + cmd = CddlCommand( + module=module_name, + name=command_name, + params=params, + required_params=required_params, + description=f"Execute {method}", + ) + + self.modules[module_name].commands.append(cmd) + logger.debug( + f"Found command: {method} with params {params_type}" + ) + + def _extract_parameters( + self, params_type: str, _seen: set[str] | None = None + ) -> dict[str, str]: + """Extract parameters from a parameter type definition. + + Handles both struct types ({...}) and top-level union types (TypeA / TypeB), + merging all fields from each alternative as optional parameters. + """ + params, _ = self._extract_parameters_and_required(params_type, _seen) + return params + + def _extract_parameters_and_required( + self, params_type: str, _seen: set[str] | None = None + ) -> tuple[dict[str, str], set[str]]: + """Extract parameters and track which are required from a parameter type definition. + + Returns: + Tuple of (params dict, required_params set) + """ + params = {} + required = set() + + if _seen is None: + _seen = set() + if params_type in _seen: + return params, required + _seen.add(params_type) + + if params_type not in self.definitions: + logger.debug(f"Parameter type not found: {params_type}") + return params, required + + definition = self.definitions[params_type] + + # Handle top-level type alias that is a union of other named types: + # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest + # These definitions contain a single line with "/" separating type names + # (not the double-slash "//" used for command unions). + stripped = definition.strip() + if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: + # Each token separated by "/" should be a named type reference + alternatives = [a.strip() for a in stripped.split("/") if a.strip()] + all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) + if all_named: + # For union types, collect parameters from all alternatives + # but treat them as optional since the caller only needs to pass one alternative + for alt_type in alternatives: + alt_params, _ = self._extract_parameters_and_required( + alt_type, _seen + ) + params.update(alt_params) + # Note: We intentionally DON'T add to required, since these are union alternatives + return params, required + + # Remove the outer curly braces and split by comma + # Then parse each line for key: type patterns + clean_def = stripped + if clean_def.startswith("{"): + clean_def = clean_def[1:] + if clean_def.endswith("}"): + clean_def = clean_def[:-1] + + # Split by newlines and process each line + for line in clean_def.split("\n"): + line = line.strip() + if not line or "Extensible" in line: + continue + + # Match pattern: [?] name: type + # Check if parameter has optional marker (?) + is_optional = line.startswith("?") + + # Using a simple pattern that handles optional prefix + match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + if not match: + # Try without optional marker + match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) + + if match: + param_name = match.group(1).strip() + param_type = match.group(2).strip() + normalized_type = self._normalize_cddl_type(param_type) + + # Skip lines that are part of nested definitions + if "{" not in normalized_type and "(" not in normalized_type: + params[param_name] = normalized_type + if not is_optional: + required.add(param_name) + logger.debug( + f"Extracted param {param_name}: {normalized_type} " + f"(required={not is_optional}) from {params_type}" + ) + + return params, required + + +def module_name_to_class_name(module_name: str) -> str: + """Convert module name to class name (PascalCase). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + """ + if "_" in module_name: + # Snake_case: browsing_context -> BrowsingContext + return "".join(word.capitalize() for word in module_name.split("_")) + else: + # CamelCase: browsingContext -> BrowsingContext + return module_name[0].upper() + module_name[1:] if module_name else "" + + +def module_name_to_filename(module_name: str) -> str: + """Convert module name to Python filename (snake_case). + + Handles both camelCase (browsingContext) and snake_case (browsing_context). + Special cases: + - browsingContext -> browsing_context + - webExtension -> webextension + """ + # Handle explicit mappings for known camelCase names + camel_to_snake_map = { + "browsingContext": "browsing_context", + "webExtension": "webextension", + } + + if module_name in camel_to_snake_map: + return camel_to_snake_map[module_name] + + if "_" in module_name: + # Already snake_case + return module_name + else: + # Convert camelCase to snake_case for other cases + # This handles cases like "myModuleName" -> "my_module_name" + import re + + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None: + """Generate __init__.py file for the module.""" + init_path = output_path / "__init__.py" + + code = f"""{SHARED_HEADER} + +from __future__ import annotations + +""" + + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + filename = module_name_to_filename(module_name) + code += f"from .{filename} import {class_name}\n" + + code += "\n__all__ = [\n" + for module_name in sorted(modules.keys()): + class_name = module_name_to_class_name(module_name) + code += f' "{class_name}",\n' + code += "]\n" + + with open(init_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {init_path}") + + +def generate_common_file(output_path: Path) -> None: + """Generate common.py file with shared utilities.""" + common_path = output_path / "common.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""Common utilities for BiDi command construction."""\n' + "\n" + "from __future__ import annotations\n" + "\n" + "from collections.abc import Generator\n" + "from typing import Any\n" + "\n" + "\n" + "def command_builder(\n" + " method: str, params: dict[str, Any]\n" + ") -> Generator[dict[str, Any], Any, Any]:\n" + ' """Build a BiDi command generator.\n' + "\n" + " Args:\n" + ' method: The BiDi method name (e.g., "session.status", "browser.close")\n' + " params: The parameters for the command\n" + "\n" + " Yields:\n" + " A dictionary representing the BiDi command\n" + "\n" + " Returns:\n" + " The result from the BiDi command execution\n" + ' """\n' + ' result = yield {"method": method, "params": params}\n' + " return result\n" + ) + + with open(common_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {common_path}") + + +def generate_console_file(output_path: Path) -> None: + """Generate console.py file with Console enum helper.""" + console_path = output_path / "console.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + "from enum import Enum\n" + "\n" + "\n" + "class Console(Enum):\n" + ' ALL = "all"\n' + ' LOG = "log"\n' + ' ERROR = "error"\n' + ) + + with open(console_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {console_path}") + + +def generate_permissions_file(output_path: Path) -> None: + """Generate permissions.py file with permission-related classes.""" + permissions_path = output_path / "permissions.py" + + code = ( + "# Licensed to the Software Freedom Conservancy (SFC) under one\n" + "# or more contributor license agreements. See the NOTICE file\n" + "# distributed with this work for additional information\n" + "# regarding copyright ownership. The SFC licenses this file\n" + "# to you under the Apache License, Version 2.0 (the\n" + '# "License"); you may not use this file except in compliance\n' + "# with the License. You may obtain a copy of the License at\n" + "#\n" + "# http://www.apache.org/licenses/LICENSE-2.0\n" + "#\n" + "# Unless required by applicable law or agreed to in writing,\n" + "# software distributed under the License is distributed on an\n" + '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n' + "# KIND, either express or implied. See the License for the\n" + "# specific language governing permissions and limitations\n" + "# under the License.\n" + "\n" + '"""WebDriver BiDi Permissions module."""\n' + "\n" + "from __future__ import annotations\n" + "\n" + "from enum import Enum\n" + "from typing import Any\n" + "\n" + "from .common import command_builder\n" + "\n" + '_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"}\n' + "\n" + "\n" + "class PermissionState(str, Enum):\n" + ' """Permission state enumeration."""\n' + "\n" + ' GRANTED = "granted"\n' + ' DENIED = "denied"\n' + ' PROMPT = "prompt"\n' + "\n" + "\n" + "class PermissionDescriptor:\n" + ' """Descriptor for a permission."""\n' + "\n" + " def __init__(self, name: str) -> None:\n" + ' """Initialize a PermissionDescriptor.\n' + "\n" + " Args:\n" + " name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera')\n" + ' """\n' + " self.name = name\n" + "\n" + " def __repr__(self) -> str:\n" + " return f\"PermissionDescriptor('{self.name}')\"\n" + "\n" + "\n" + "class Permissions:\n" + ' """WebDriver BiDi Permissions module."""\n' + "\n" + " def __init__(self, websocket_connection: Any) -> None:\n" + ' """Initialize the Permissions module.\n' + "\n" + " Args:\n" + " websocket_connection: The WebSocket connection for sending BiDi commands\n" + ' """\n' + " self._conn = websocket_connection\n" + "\n" + " def set_permission(\n" + " self,\n" + " descriptor: PermissionDescriptor | str,\n" + " state: PermissionState | str,\n" + " origin: str | None = None,\n" + " user_context: str | None = None,\n" + " ) -> None:\n" + ' """Set a permission for a given origin.\n' + "\n" + " Args:\n" + " descriptor: The permission descriptor or permission name as a string\n" + " state: The desired permission state\n" + " origin: The origin for which to set the permission\n" + " user_context: Optional user context ID to scope the permission\n" + "\n" + " Raises:\n" + " ValueError: If the state is not a valid permission state\n" + ' """\n' + " state_value = state.value if isinstance(state, PermissionState) else state\n" + " if state_value not in _VALID_PERMISSION_STATES:\n" + " raise ValueError(\n" + ' f"Invalid permission state: {state_value!r}. "\n' + ' f"Must be one of {sorted(_VALID_PERMISSION_STATES)}"\n' + " )\n" + "\n" + " if isinstance(descriptor, str):\n" + ' descriptor_dict = {"name": descriptor}\n' + " else:\n" + ' descriptor_dict = {"name": descriptor.name}\n' + "\n" + " params: dict[str, Any] = {\n" + ' "descriptor": descriptor_dict,\n' + ' "state": state_value,\n' + " }\n" + " if origin is not None:\n" + ' params["origin"] = origin\n' + " if user_context is not None:\n" + ' params["userContext"] = user_context\n' + "\n" + ' cmd = command_builder("permissions.setPermission", params)\n' + " self._conn.execute(cmd)\n" + ) + + with open(permissions_path, "w", encoding="utf-8") as f: + f.write(code) + + logger.info(f"Generated: {permissions_path}") + + +def main( + cddl_file: str, + output_dir: str, + spec_version: str = "1.0", + enhancements_manifest: str | None = None, +) -> None: + """Main entry point. + + Args: + cddl_file: Path to CDDL specification file + output_dir: Output directory for generated modules + spec_version: BiDi spec version + enhancements_manifest: Path to enhancement manifest Python file + """ + output_path = Path(output_dir).resolve() + output_path.mkdir(parents=True, exist_ok=True) + + logger.info(f"WebDriver BiDi Code Generator v{__version__}") + logger.info(f"Input CDDL: {cddl_file}") + logger.info(f"Output directory: {output_path}") + logger.info(f"Spec version: {spec_version}") + + # Load enhancement manifest + manifest = load_enhancements_manifest(enhancements_manifest) + if manifest: + logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}") + + # Parse CDDL + parser = CddlParser(cddl_file) + modules = parser.parse() + + logger.info(f"Parsed {len(modules)} modules") + + # Clean up existing generated files + for file_path in output_path.glob("*.py"): + if file_path.name != "py.typed" and not file_path.name.startswith("_"): + file_path.unlink() + logger.debug(f"Removed: {file_path}") + + # Generate module files using snake_case filenames + for module_name, module in sorted(modules.items()): + filename = module_name_to_filename(module_name) + module_path = output_path / f"{filename}.py" + + # Get module-specific enhancements (merge with dataclass templates) + module_enhancements = manifest.get("enhancements", {}).get(module_name, {}) + + # Add dataclass methods and docstrings to the enhancement data for this module + full_module_enhancements = { + **module_enhancements, + "dataclass_methods": manifest.get("dataclass_methods", {}), + "method_docstrings": manifest.get("method_docstrings", {}), + } + + with open(module_path, "w", encoding="utf-8") as f: + f.write(module.generate_code(full_module_enhancements)) + logger.info(f"Generated: {module_path}") + + # Generate __init__.py + generate_init_file(output_path, modules) + + # Generate common.py + generate_common_file(output_path) + + # Generate permissions.py + generate_permissions_file(output_path) + + # Generate console.py + generate_console_file(output_path) + + # Create py.typed marker + py_typed_path = output_path / "py.typed" + py_typed_path.touch() + logger.info(f"Generated type marker: {py_typed_path}") + + logger.info("Code generation complete!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate Python WebDriver BiDi modules from CDDL specification" + ) + parser.add_argument( + "cddl_file", + help="Path to CDDL specification file", + ) + parser.add_argument( + "output_dir", + help="Output directory for generated Python modules", + ) + parser.add_argument( + "spec_version", + nargs="?", + default="1.0", + help="BiDi spec version (default: 1.0)", + ) + parser.add_argument( + "--enhancements-manifest", + default=None, + help="Path to enhancement manifest Python file (optional)", + ) + parser.add_argument( + "-v", + "--verbose", + action="store_true", + help="Enable verbose logging", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger("generate_bidi").setLevel(logging.DEBUG) + + try: + main( + args.cddl_file, + args.output_dir, + args.spec_version, + args.enhancements_manifest, + ) + sys.exit(0) + except Exception as e: + logger.error(f"Generation failed: {e}", exc_info=True) + sys.exit(1) diff --git a/py/private/BUILD.bazel b/py/private/BUILD.bazel index 8b02ac341a0dc..88acc9d2aba11 100644 --- a/py/private/BUILD.bazel +++ b/py/private/BUILD.bazel @@ -1,5 +1,10 @@ load("@rules_python//python:defs.bzl", "py_binary") +exports_files([ + "bidi_enhancements_manifest.py", + "cdp.py", +]) + py_binary( name = "untar", srcs = [ diff --git a/py/private/bidi_enhancements_manifest.py b/py/private/bidi_enhancements_manifest.py new file mode 100644 index 0000000000000..dcf464f425e9d --- /dev/null +++ b/py/private/bidi_enhancements_manifest.py @@ -0,0 +1,1687 @@ +""" +Enhancement manifest for BiDi code generation. + +This file defines custom enhancements applied to generated BiDi modules, +including custom dataclass methods, parameter validation/transformation, +response deserialization, and field extraction. + +All code must be compatible with Python 3.10+. +""" + +from __future__ import annotations + +from typing import Any + +# ============================================================================ +# Format Guide +# ============================================================================ +# Each module in ENHANCEMENTS specifies enhancement rules for methods: +# +# 'module_name': { +# 'method_name': { +# 'dataclass_methods': { # For dataclass enhancements +# 'ClassName': ['method1', 'method2', ...] +# }, +# 'preprocess': { # Pre-processing on parameters +# 'param_name': 'check_serialize_method' +# }, +# 'deserialize': { # Deserialize response to typed objects +# 'response_field': 'TypeName', +# }, +# 'extract_field': str, # Extract nested field from response +# 'extract_property': str, # Extract property from extracted items +# 'validate': str, # Validation function name +# 'transform': str, # Transformation function name +# } +# } +# ============================================================================ + +ENHANCEMENTS: dict[str, dict[str, Any]] = { + + "browser": { + # Dataclass custom methods + "__dataclass_methods__": { + "ClientWindowInfo": [ + "get_client_window", + "get_state", + "get_width", + "get_height", + "is_active", + "get_x", + "get_y", + ], + }, + # Method enhancements + "create_user_context": { + "preprocess": { + "proxy": "check_serialize_method", + "unhandled_prompt_behavior": "check_serialize_method", + }, + "extract_field": "userContext", + }, + "get_client_windows": { + "deserialize": { + "clientWindows": "ClientWindowInfo", + }, + }, + "get_user_contexts": { + "extract_field": "userContexts", + "extract_property": "userContext", + }, + "set_download_behavior": { + "params_override": { + "allowed": "bool", + "destination_folder": "str", + "userContexts": "[*browser.UserContext]", + }, + "validate": "validate_download_behavior", + "transform": { + "allowed": "allowed", + "destination_folder": "destination_folder", + "func": "transform_download_params", + "result_param": "download_behavior", + }, + }, + # Replace the auto-generated ClientWindowNamedState so we can add the + # convenience NORMAL constant. In the BiDi spec "normal" is the state + # represented by ClientWindowRectState, but exposing it here keeps the + # Python API consistent with the old ClientWindowState enum. + "exclude_types": ["ClientWindowNamedState"], + "extra_dataclasses": [ + '''class ClientWindowNamedState: + """Named states for a browser client window.""" + + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + NORMAL = "normal"''', + ], + # Override the generator-produced set_download_behavior so that + # downloadBehavior is never stripped by the generic None filter. + # The BiDi spec marks it as required (can be null, but must be present). + "extra_methods": [ + ''' def set_download_behavior( + self, + allowed: bool | None = None, + destination_folder: str | None = None, + user_contexts: list[Any] | None = None, + ): + """Set the download behavior for the browser. + + Args: + allowed: ``True`` to allow downloads, ``False`` to deny, or ``None`` + to reset to browser default (sends ``null`` to the protocol). + destination_folder: Destination folder for downloads. Required when + ``allowed=True``. Accepts a string or :class:`pathlib.Path`. + user_contexts: Optional list of user context IDs. + + Raises: + ValueError: If *allowed* is ``True`` and *destination_folder* is + omitted, or ``False`` and *destination_folder* is provided. + """ + validate_download_behavior( + allowed=allowed, + destination_folder=destination_folder, + user_contexts=user_contexts, + ) + download_behavior = transform_download_params(allowed, destination_folder) + # downloadBehavior is a REQUIRED field in the BiDi spec (can be null but + # must be present). Do NOT use a generic None-filter on it. + params: dict = {"downloadBehavior": download_behavior} + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd)''', + ''' def set_client_window_state( + self, + client_window: Any | None = None, + state: Any | None = None, + ): + """Set the client window state. + + Args: + client_window: The client window ID to apply the state to. + state: The window state to set. Can be one of: + - A string: "fullscreen", "maximized", "minimized", "normal" + - A ClientWindowRectState object with width, height, x, y + - A dict representing the state + + Raises: + ValueError: If client_window is not provided or state is invalid. + """ + if client_window is None: + raise ValueError("client_window is required") + if state is None: + raise ValueError("state is required") + + # Serialize ClientWindowRectState if needed + state_param = state + if hasattr(state, '__dataclass_fields__'): + # It's a dataclass, convert to dict + state_param = { + k: v for k, v in state.__dict__.items() + if v is not None + } + + params = { + "clientWindow": client_window, + "state": state_param, + } + cmd = command_builder("browser.setClientWindowState", params) + return self._conn.execute(cmd)''', + ], + }, + + "browsingContext": { + # Method enhancements + "create": { + "extract_field": "context", + }, + "get_tree": { + "extract_field": "contexts", + "deserialize": { + "contexts": "Info", + }, + }, + "capture_screenshot": { + "extract_field": "data", + "params_override": { + "context": "str", + "format": "ImageFormat", + "clip": "BoxClipRectangle", + "origin": "str", + }, + }, + "print": { + "extract_field": "data", + }, + "locate_nodes": { + "extract_field": "nodes", + "params_override": { + "context": "str", + "locator": "dict", + "serializationOptions": "dict", + "startNodes": "list", + "maxNodeCount": "int", + }, + }, + "set_viewport": { + "params_override": { + "context": "str", + "viewport": "dict", + "userContexts": "list", + "devicePixelRatio": "float", + }, + }, + # Non-CDDL download event dataclasses (Chromium-specific) + "extra_dataclasses": [ + '''@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None''', + '''@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" + + status: Any | None = None''', + '''@dataclass +class DownloadParams: + """DownloadParams - fields shared by all download end event variants.""" + + status: str | None = None + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None + filepath: str | None = None''', + '''@dataclass +class DownloadEndParams: + """DownloadEndParams - params for browsingContext.downloadEnd event.""" + + download_params: "DownloadParams | None" = None + + @classmethod + def from_json(cls, params: dict) -> "DownloadEndParams": + """Deserialize from BiDi wire-level params dict.""" + dp = DownloadParams( + status=params.get("status"), + context=params.get("context"), + navigation=params.get("navigation"), + timestamp=params.get("timestamp"), + url=params.get("url"), + filepath=params.get("filepath"), + ) + return cls(download_params=dp)''', + ], + # Download events are now in the CDDL spec, so no extra_events needed + }, + + "log": { + # Make LogLevel an alias for Level so existing code using LogLevel works + "aliases": {"LogLevel": "Level"}, + # Replace the minimal CDDL-generated versions with richer ones that have from_json + "exclude_types": ["JavascriptLogEntry"], + "extra_dataclasses": [ + '''@dataclass +class ConsoleLogEntry: + """ConsoleLogEntry - a console log entry from the browser.""" + + type_: str | None = None + method: str | None = None + args: list | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None + + @classmethod + def from_json(cls, params: dict) -> "ConsoleLogEntry": + """Deserialize from BiDi params dict.""" + return cls( + type_=params.get("type"), + method=params.get("method"), + args=params.get("args"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stack_trace=params.get("stackTrace"), + )''', + '''@dataclass +class JavascriptLogEntry: + """JavascriptLogEntry - a JavaScript error log entry from the browser.""" + + type_: str | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stacktrace: Any | None = None + + @classmethod + def from_json(cls, params: dict) -> "JavascriptLogEntry": + """Deserialize from BiDi params dict.""" + return cls( + type_=params.get("type"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stacktrace=params.get("stackTrace"), + )''', + ], + # Define Entry union type for log.entryAdded event deserialization + "extra_type_aliases": [ + "Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry", + ], + "event_type_aliases": { + "entry_added": "Entry", + }, + }, + + "emulation": { + "extra_methods": [ + ''' def set_geolocation_override( + self, + coordinates=None, + error=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setGeolocationOverride. + + Sets or clears the geolocation override for specified browsing or user contexts. + + Args: + coordinates: A GeolocationCoordinates instance (or dict) to override the + position, or ``None`` to clear a previously-set override. + error: A GeolocationPositionError instance (or dict) to simulate a + position-unavailable error. Mutually exclusive with *coordinates*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params: dict[str, Any] = {} + if coordinates is not None: + if isinstance(coordinates, dict): + coords_dict = coordinates + else: + coords_dict = {} + if coordinates.latitude is not None: + coords_dict["latitude"] = coordinates.latitude + if coordinates.longitude is not None: + coords_dict["longitude"] = coordinates.longitude + if coordinates.accuracy is not None: + coords_dict["accuracy"] = coordinates.accuracy + if coordinates.altitude is not None: + coords_dict["altitude"] = coordinates.altitude + if coordinates.altitude_accuracy is not None: + coords_dict["altitudeAccuracy"] = coordinates.altitude_accuracy + if coordinates.heading is not None: + coords_dict["heading"] = coordinates.heading + if coordinates.speed is not None: + coords_dict["speed"] = coordinates.speed + params["coordinates"] = coords_dict + if error is not None: + if isinstance(error, dict): + params["error"] = error + else: + params["error"] = { + "type": error.type if error.type is not None else "positionUnavailable" + } + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result''', + ''' def set_timezone_override( + self, + timezone=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setTimezoneOverride. + + Sets or clears the timezone override for specified browsing or user contexts. + Pass ``timezone=None`` (or omit it) to clear a previously-set override. + + Args: + timezone: IANA timezone string (e.g. ``"America/New_York"``) or ``None`` + to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params: dict[str, Any] = {"timezone": timezone} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setTimezoneOverride", params) + return self._conn.execute(cmd)''', + ''' def set_scripting_enabled( + self, + enabled=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScriptingEnabled. + + Enables or disables scripting for specified browsing or user contexts. + Pass ``enabled=None`` to restore the default behaviour. + + Args: + enabled: ``True`` to enable scripting, ``False`` to disable it, or + ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params: dict[str, Any] = {"enabled": enabled} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScriptingEnabled", params) + return self._conn.execute(cmd)''', + ''' def set_user_agent_override( + self, + user_agent=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setUserAgentOverride. + + Overrides the User-Agent string for specified browsing or user contexts. + Pass ``user_agent=None`` to clear a previously-set override. + + Args: + user_agent: Custom User-Agent string, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params: dict[str, Any] = {"userAgent": user_agent} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setUserAgentOverride", params) + return self._conn.execute(cmd)''', + ''' def set_screen_orientation_override( + self, + screen_orientation=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScreenOrientationOverride. + + Sets or clears the screen orientation override for specified browsing or + user contexts. + + Args: + screen_orientation: A :class:`ScreenOrientation` instance (or dict with + ``natural`` and ``type`` keys) to lock the orientation, or ``None`` + to clear a previously-set override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if screen_orientation is None: + so_value = None + elif isinstance(screen_orientation, dict): + so_value = screen_orientation + else: + natural = screen_orientation.natural + orientation_type = screen_orientation.type + so_value = { + "natural": natural.lower() if isinstance(natural, str) else natural, + "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, + } + params: dict[str, Any] = {"screenOrientation": so_value} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScreenOrientationOverride", params) + return self._conn.execute(cmd)''', + ''' def set_network_conditions( + self, + network_conditions=None, + offline: bool | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setNetworkConditions. + + Sets or clears network condition emulation for specified browsing or user + contexts. + + Args: + network_conditions: A dict with the raw ``networkConditions`` value + (e.g. ``{"type": "offline"}``), or ``None`` to clear the override. + Mutually exclusive with *offline*. + offline: Convenience bool — ``True`` sets offline conditions, + ``False`` clears them (sends ``null``). When provided, this takes + precedence over *network_conditions*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + if offline is not None: + nc_value = {"type": "offline"} if offline else None + else: + nc_value = network_conditions + params: dict[str, Any] = {"networkConditions": nc_value} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setNetworkConditions", params) + return self._conn.execute(cmd)''', + ], + }, + + "script": { + "extra_methods": [ + ''' def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: + """Execute a function declaration in the browser context. + + Args: + function_declaration: The function as a string, e.g. ``"() => document.title"``. + *args: Optional Python values to pass as arguments to the function. + Each value is serialised to a BiDi ``LocalValue`` automatically. + Supported types: ``None``, ``bool``, ``int``, ``float`` + (including ``NaN`` and ``Infinity``), ``str``, ``list``, + ``dict``, and ``datetime.datetime``. + context_id: The browsing context ID to run in. Defaults to the + driver\'s current window handle when a driver was provided. + + Returns: + The inner RemoteValue result dict, or raises WebDriverException on exception. + """ + import math as _math + import datetime as _datetime + from selenium.common.exceptions import WebDriverException as _WebDriverException + + def _serialize_arg(value): + """Serialise a Python value to a BiDi LocalValue dict.""" + if value is None: + return {"type": "null"} + if isinstance(value, bool): + return {"type": "boolean", "value": value} + if isinstance(value, _datetime.datetime): + return {"type": "date", "value": value.isoformat()} + if isinstance(value, float): + if _math.isnan(value): + return {"type": "number", "value": "NaN"} + if _math.isinf(value): + return {"type": "number", "value": "Infinity" if value > 0 else "-Infinity"} + return {"type": "number", "value": value} + if isinstance(value, int): + _MAX_SAFE_INT = 9007199254740991 + if abs(value) > _MAX_SAFE_INT: + return {"type": "bigint", "value": str(value)} + return {"type": "number", "value": value} + if isinstance(value, str): + return {"type": "string", "value": value} + if isinstance(value, list): + return {"type": "array", "value": [_serialize_arg(v) for v in value]} + if isinstance(value, dict): + return {"type": "object", "value": [[str(k), _serialize_arg(v)] for k, v in value.items()]} + return value + + if context_id is None and self._driver is not None: + try: + context_id = self._driver.current_window_handle + except Exception: + pass + target = {"context": context_id} if context_id else {} + serialized_args = [_serialize_arg(a) for a in args] if args else None + raw = self.call_function( + function_declaration=function_declaration, + await_promise=True, + target=target, + arguments=serialized_args, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails", {}) + msg = exc.get("text", str(exc)) if isinstance(exc, dict) else str(exc) + raise _WebDriverException(msg) + if raw.get("type") == "success": + return raw.get("result") + return raw''', + ''' def _add_preload_script( + self, + function_declaration, + arguments=None, + contexts=None, + user_contexts=None, + sandbox=None, + ): + """Add a preload script with validation. + + Args: + function_declaration: The JS function to run on page load. + arguments: Optional list of BiDi arguments. + contexts: Optional list of browsing context IDs. + user_contexts: Optional list of user context IDs. + sandbox: Optional sandbox name. + + Returns: + script_id: The ID of the added preload script (str). + + Raises: + ValueError: If both contexts and user_contexts are specified. + """ + if contexts is not None and user_contexts is not None: + raise ValueError("Cannot specify both contexts and user_contexts") + result = self.add_preload_script( + function_declaration=function_declaration, + arguments=arguments, + contexts=contexts, + user_contexts=user_contexts, + sandbox=sandbox, + ) + if isinstance(result, dict): + return result.get("script") + return result''', + ''' def _remove_preload_script(self, script_id): + """Remove a preload script by ID. + + Args: + script_id: The ID of the preload script to remove. + """ + return self.remove_preload_script(script=script_id)''', + ''' def pin(self, function_declaration): + """Pin (add) a preload script that runs on every page load. + + Args: + function_declaration: The JS function to execute on page load. + + Returns: + script_id: The ID of the pinned script (str). + """ + return self._add_preload_script(function_declaration)''', + ''' def unpin(self, script_id): + """Unpin (remove) a previously pinned preload script. + + Args: + script_id: The ID returned by pin(). + """ + return self._remove_preload_script(script_id=script_id)''', + ''' def _evaluate( + self, + expression, + target, + await_promise, + result_ownership=None, + serialization_options=None, + user_activation=None, + ): + """Evaluate a script expression and return a structured result. + + Args: + expression: The JavaScript expression to evaluate. + target: A dict like {"context": } or {"realm": }. + await_promise: Whether to await a returned promise. + result_ownership: Optional result ownership setting. + serialization_options: Optional serialization options dict. + user_activation: Optional user activation flag. + + Returns: + An object with .realm, .result (dict or None), and .exception_details (or None). + """ + class _EvalResult: + def __init__(self2, realm, result, exception_details): + self2.realm = realm + self2.result = result + self2.exception_details = exception_details + + raw = self.evaluate( + expression=expression, + target=target, + await_promise=await_promise, + result_ownership=result_ownership, + serialization_options=serialization_options, + user_activation=user_activation, + ) + if isinstance(raw, dict): + realm = raw.get("realm") + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _EvalResult(realm=realm, result=None, exception_details=exc) + return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) + return _EvalResult(realm=None, result=raw, exception_details=None)''', + ''' def _call_function( + self, + function_declaration, + await_promise, + target, + arguments=None, + result_ownership=None, + this=None, + user_activation=None, + serialization_options=None, + ): + """Call a function and return a structured result. + + Args: + function_declaration: The JS function string. + await_promise: Whether to await the return value. + target: A dict like {"context": }. + arguments: Optional list of BiDi arguments. + result_ownership: Optional result ownership. + this: Optional \'this\' binding. + user_activation: Optional user activation flag. + serialization_options: Optional serialization options dict. + + Returns: + An object with .result (dict or None) and .exception_details (or None). + """ + class _CallResult: + def __init__(self2, result, exception_details): + self2.result = result + self2.exception_details = exception_details + + raw = self.call_function( + function_declaration=function_declaration, + await_promise=await_promise, + target=target, + arguments=arguments, + result_ownership=result_ownership, + this=this, + user_activation=user_activation, + serialization_options=serialization_options, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _CallResult(result=None, exception_details=exc) + if raw.get("type") == "success": + return _CallResult(result=raw.get("result"), exception_details=None) + return _CallResult(result=raw, exception_details=None)''', + ''' def _get_realms(self, context=None, type=None): + """Get all realms, optionally filtered by context and type. + + Args: + context: Optional browsing context ID to filter by. + type: Optional realm type string to filter by (e.g. RealmType.WINDOW). + + Returns: + List of realm info objects with .realm, .origin, .type, .context attributes. + """ + class _RealmInfo: + def __init__(self2, realm, origin, type_, context): + self2.realm = realm + self2.origin = origin + self2.type = type_ + self2.context = context + + raw = self.get_realms(context=context, type=type) + realms_list = raw.get("realms", []) if isinstance(raw, dict) else [] + result = [] + for r in realms_list: + if isinstance(r, dict): + result.append(_RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + )) + return result''', + ''' def _disown(self, handles, target): + """Disown handles in a browsing context. + + Args: + handles: List of handle strings to disown. + target: A dict like {"context": }. + """ + return self.disown(handles=handles, target=target)''', + ''' def _subscribe_log_entry(self, callback, entry_type_filter=None): + """Subscribe to log.entryAdded BiDi events with optional type filtering.""" + import threading as _threading + from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + + bidi_event = "log.entryAdded" + + if not hasattr(self, "_log_subscriptions"): + self._log_subscriptions = {} + self._log_lock = _threading.Lock() + + def _deserialize(params): + t = params.get("type") if isinstance(params, dict) else None + if t == "console": + cls = getattr(_log_mod, "ConsoleLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + elif t == "javascript": + cls = getattr(_log_mod, "JavascriptLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + return params + + def _wrapped(raw): + entry = _deserialize(raw) + if entry_type_filter is None: + callback(entry) + else: + t = getattr(entry, "type_", None) or ( + entry.get("type") if isinstance(entry, dict) else None + ) + if t == entry_type_filter: + callback(entry) + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + callback_id = self._conn.add_callback(_wrapper, _wrapped) + with self._log_lock: + if bidi_event not in self._log_subscriptions: + session = _Session(self._conn) + result = session.subscribe([bidi_event]) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self._log_subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) + return callback_id''', + ''' def _unsubscribe_log_entry(self, callback_id): + """Unsubscribe a log entry callback by ID.""" + from selenium.webdriver.common.bidi.session import Session as _Session + + bidi_event = "log.entryAdded" + if not hasattr(self, "_log_subscriptions"): + return + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + self._conn.remove_callback(_wrapper, callback_id) + with self._log_lock: + entry = self._log_subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + if entry is not None and not entry["callbacks"]: + session = _Session(self._conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self._log_subscriptions[bidi_event]''', + ''' def add_console_message_handler(self, callback: Callable) -> int: + """Add a handler for console log messages (log.entryAdded type=console). + + Args: + callback: Function called with a ConsoleLogEntry on each console message. + + Returns: + callback_id for use with remove_console_message_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="console")''', + ''' def remove_console_message_handler(self, callback_id: int) -> None: + """Remove a console message handler by callback ID.""" + self._unsubscribe_log_entry(callback_id)''', + ''' def add_javascript_error_handler(self, callback: Callable) -> int: + """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). + + Args: + callback: Function called with a JavascriptLogEntry on each JS error. + + Returns: + callback_id for use with remove_javascript_error_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="javascript")''', + ''' def remove_javascript_error_handler(self, callback_id: int) -> None: + """Remove a JavaScript error handler by callback ID.""" + self._unsubscribe_log_entry(callback_id)''', + ], + }, + + "network": { + # Initialize intercepts tracking list and per-handler intercept map + "extra_init_code": [ + "self.intercepts: list[Any] = []", + "self._handler_intercepts: dict[str, Any] = {}", + ], + # Request class wraps a beforeRequestSent event params and provides actions + "extra_dataclasses": [ + '''class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: Any | None, value: Any | None) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value}''', + '''class Request: + """Wraps a BiDi network request event params and provides request action methods.""" + + def __init__(self, conn, params): + self._conn = conn + self._params = params if isinstance(params, dict) else {} + req = self._params.get("request", {}) or {} + self.url = req.get("url", "") + self._request_id = req.get("request") + + def continue_request(self, **kwargs): + """Continue the intercepted request.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + params = {"request": self._request_id} + params.update(kwargs) + self._conn.execute(_cb("network.continueRequest", params))''', + ], + # Add before_request event (maps to network.beforeRequestSent) + "extra_events": [ + { + "event_key": "before_request", + "bidi_event": "network.beforeRequestSent", + "event_class": "dict", + }, + ], + "extra_methods": [ + ''' def _add_intercept(self, phases=None, url_patterns=None): + """Add a low-level network intercept. + + Args: + phases: list of intercept phases (default: ["beforeRequestSent"]) + url_patterns: optional URL patterns to filter + + Returns: + dict with "intercept" key containing the intercept ID + """ + from selenium.webdriver.common.bidi.common import command_builder as _cb + + if phases is None: + phases = ["beforeRequestSent"] + params = {"phases": phases} + if url_patterns: + params["urlPatterns"] = url_patterns + result = self._conn.execute(_cb("network.addIntercept", params)) + if result: + intercept_id = result.get("intercept") + if intercept_id and intercept_id not in self.intercepts: + self.intercepts.append(intercept_id) + return result''', + ''' def _remove_intercept(self, intercept_id): + """Remove a low-level network intercept.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) + if intercept_id in self.intercepts: + self.intercepts.remove(intercept_id)''', + ''' def add_request_handler(self, event, callback, url_patterns=None): + """Add a handler for network requests at the specified phase. + + Args: + event: Event name, e.g. ``"before_request"``. + callback: Callable receiving a :class:`Request` instance. + url_patterns: optional list of URL pattern dicts to filter. + + Returns: + callback_id int for later removal via remove_request_handler. + """ + phase_map = { + "before_request": "beforeRequestSent", + "before_request_sent": "beforeRequestSent", + "response_started": "responseStarted", + "auth_required": "authRequired", + } + phase = phase_map.get(event, "beforeRequestSent") + intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_id = intercept_result.get("intercept") if intercept_result else None + + def _request_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request = Request(self._conn, raw) + callback(request) + + callback_id = self.add_event_handler(event, _request_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id''', + ''' def remove_request_handler(self, event, callback_id): + """Remove a network request handler and its associated network intercept. + + Args: + event: The event name used when adding the handler. + callback_id: The int returned by add_request_handler. + """ + self.remove_event_handler(event, callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id)''', + ''' def clear_request_handlers(self): + """Clear all request handlers and remove all tracked intercepts.""" + self.clear_event_handlers() + for intercept_id in list(self.intercepts): + self._remove_intercept(intercept_id)''', + ''' def add_auth_handler(self, username, password): + """Add an auth handler that automatically provides credentials. + + Args: + username: The username for basic authentication. + password: The password for basic authentication. + + Returns: + callback_id int for later removal via remove_auth_handler. + """ + from selenium.webdriver.common.bidi.common import command_builder as _cb + + # Set up network intercept for authRequired phase + intercept_result = self._add_intercept(phases=["authRequired"]) + intercept_id = intercept_result.get("intercept") if intercept_result else None + + def _auth_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request_id = ( + raw.get("request", {}).get("request") + if isinstance(raw, dict) + else None + ) + if request_id: + self._conn.execute( + _cb( + "network.continueWithAuth", + { + "request": request_id, + "action": "provideCredentials", + "credentials": { + "type": "password", + "username": username, + "password": password, + }, + }, + ) + ) + + callback_id = self.add_event_handler("auth_required", _auth_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id''', + ''' def remove_auth_handler(self, callback_id): + """Remove an auth handler by callback ID and its associated network intercept. + + Args: + callback_id: The handler ID returned by add_auth_handler. + """ + self.remove_event_handler("auth_required", callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id)''', + ], + }, + + "storage": { + # Exclude auto-generated dataclasses that need custom to_bidi_dict() + # for JSON-over-WebSocket serialization, or custom constructors. + "exclude_types": [ + "CookieFilter", + "PartialCookie", + "BrowsingContextPartitionDescriptor", + "StorageKeyPartitionDescriptor", + ], + "extra_dataclasses": [ + # Re-export network types used in cookie operations so they can be + # imported from selenium.webdriver.common.bidi.storage alongside + # the storage-specific classes. + '''class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: Any | None, value: Any | None) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value}''', + '''class SameSite: + """SameSite cookie attribute values.""" + + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default"''', + # Helper: cookie object returned inside a GetCookiesResult response + '''@dataclass +class StorageCookie: + """A cookie object returned by storage.getCookies.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + @classmethod + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + """Deserialize a wire-level cookie dict to a StorageCookie.""" + value_raw = raw.get("value") + if isinstance(value_raw, dict): + value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) + else: + value = value_raw + return cls( + name=raw.get("name"), + value=value, + domain=raw.get("domain"), + path=raw.get("path"), + size=raw.get("size"), + http_only=raw.get("httpOnly"), + secure=raw.get("secure"), + same_site=raw.get("sameSite"), + expiry=raw.get("expiry"), + )''', + # Custom CookieFilter with camelCase serialization + '''@dataclass +class CookieFilter: + """CookieFilter.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain + if self.path is not None: + result["path"] = self.path + if self.size is not None: + result["size"] = self.size + if self.http_only is not None: + result["httpOnly"] = self.http_only + if self.secure is not None: + result["secure"] = self.secure + if self.same_site is not None: + result["sameSite"] = self.same_site + if self.expiry is not None: + result["expiry"] = self.expiry + return result''', + # Custom PartialCookie with camelCase serialization + '''@dataclass +class PartialCookie: + """PartialCookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain + if self.path is not None: + result["path"] = self.path + if self.http_only is not None: + result["httpOnly"] = self.http_only + if self.secure is not None: + result["secure"] = self.secure + if self.same_site is not None: + result["sameSite"] = self.same_site + if self.expiry is not None: + result["expiry"] = self.expiry + return result''', + # BrowsingContextPartitionDescriptor: first positional arg is *context* + # (the auto-generated dataclass had `type` first, breaking positional + # usage like BrowsingContextPartitionDescriptor(driver.current_window_handle)) + '''class BrowsingContextPartitionDescriptor: + """BrowsingContextPartitionDescriptor. + + The first positional argument is *context* (a browsing-context ID / window + handle), mirroring how the class is used throughout the test suite: + ``BrowsingContextPartitionDescriptor(driver.current_window_handle)``. + """ + + def __init__(self, context: Any = None, type: str = "context") -> None: + self.context = context + self.type = type + + def to_bidi_dict(self) -> dict: + return {"type": "context", "context": self.context}''', + # StorageKeyPartitionDescriptor with camelCase serialization + '''@dataclass +class StorageKeyPartitionDescriptor: + """StorageKeyPartitionDescriptor.""" + + type: Any | None = "storageKey" + user_context: str | None = None + source_origin: str | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {"type": "storageKey"} + if self.user_context is not None: + result["userContext"] = self.user_context + if self.source_origin is not None: + result["sourceOrigin"] = self.source_origin + return result''', + ], + # Override the generated Storage class methods (Python's last-definition- + # wins semantics means these extra_methods shadow the generated ones). + "extra_methods": [ + ''' def get_cookies(self, filter=None, partition=None): + """Execute storage.getCookies and return a GetCookiesResult.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + if result and "cookies" in result: + cookies = [ + StorageCookie.from_bidi_dict(c) + for c in result.get("cookies", []) + if isinstance(c, dict) + ] + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None + ) + return GetCookiesResult(cookies=cookies, partition_key=pk) + return GetCookiesResult(cookies=[], partition_key=None)''', + ''' def set_cookie(self, cookie=None, partition=None): + """Execute storage.setCookie.""" + if cookie and hasattr(cookie, "to_bidi_dict"): + cookie = cookie.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result''', + ''' def delete_cookies(self, filter=None, partition=None): + """Execute storage.deleteCookies.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result''', + ], + }, + + "session": { + # Override UserPromptHandler to add to_bidi_dict() for JSON serialization + "exclude_types": ["UserPromptHandler"], + "extra_dataclasses": [ + '''@dataclass +class UserPromptHandler: + """UserPromptHandler.""" + + alert: Any | None = None + before_unload: Any | None = None + confirm: Any | None = None + default: Any | None = None + file: Any | None = None + prompt: Any | None = None + + def to_bidi_dict(self) -> dict: + """Convert to BiDi protocol dict with camelCase keys.""" + result = {} + if self.alert is not None: + result["alert"] = self.alert + if self.before_unload is not None: + result["beforeUnload"] = self.before_unload + if self.confirm is not None: + result["confirm"] = self.confirm + if self.default is not None: + result["default"] = self.default + if self.file is not None: + result["file"] = self.file + if self.prompt is not None: + result["prompt"] = self.prompt + return result''', + ], + }, + + "webExtension": { + # Suppress the raw generated stubs; hand-written versions follow below + "exclude_methods": ["install", "uninstall"], + "extra_methods": [ + ''' def install( + self, + path: str | None = None, + archive_path: str | None = None, + base64_value: str | None = None, + ): + """Install a web extension. + + Exactly one of the three keyword arguments must be provided. + + Args: + path: Directory path to an unpacked extension (also accepted for + signed ``.xpi`` / ``.crx`` archive files on Firefox). + archive_path: File-system path to a packed extension archive. + base64_value: Base64-encoded extension archive string. + + Returns: + The raw result dict from the BiDi ``webExtension.install`` command + (contains at least an ``"extension"`` key with the extension ID). + + Raises: + ValueError: If more than one, or none, of the arguments is provided. + """ + provided = [ + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None + ] + if len(provided) != 1: + raise ValueError( + f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" + ) + if path is not None: + extension_data = {"type": "path", "path": path} + elif archive_path is not None: + extension_data = {"type": "archivePath", "path": archive_path} + else: + assert base64_value is not None + extension_data = {"type": "base64", "value": base64_value} + params = {"extensionData": extension_data} + cmd = command_builder("webExtension.install", params) + return self._conn.execute(cmd)''', + ''' def uninstall(self, extension: str | dict): + """Uninstall a web extension. + + Args: + extension: Either the extension ID string returned by ``install``, + or the full result dict returned by ``install`` (the + ``"extension"`` value is extracted automatically). + + Raises: + ValueError: If extension is not provided or is None. + """ + if isinstance(extension, dict): + extension_id: Any = extension.get("extension") + else: + extension_id = extension + + if extension_id is None: + raise ValueError("extension parameter is required") + + params = {"extension": extension_id} + cmd = command_builder("webExtension.uninstall", params) + return self._conn.execute(cmd)''', + ], + }, + + "input": { + # FileDialogInfo needs from_json for event deserialization + "exclude_types": ["FileDialogInfo", "PointerMoveAction", "PointerDownAction"], + "extra_dataclasses": [ + '''@dataclass +class FileDialogInfo: + """FileDialogInfo - parameters for the input.fileDialogOpened event.""" + + context: Any | None = None + element: Any | None = None + multiple: bool | None = None + + @classmethod + def from_json(cls, params: dict) -> "FileDialogInfo": + """Deserialize event params into FileDialogInfo.""" + return cls( + context=params.get("context"), + element=params.get("element"), + multiple=params.get("multiple"), + )''', + '''@dataclass +class PointerMoveAction: + """PointerMoveAction.""" + + type: str = field(default="pointerMove", init=False) + x: Any | None = None + y: Any | None = None + duration: Any | None = None + origin: Any | None = None + properties: Any | None = None''', + '''@dataclass +class PointerDownAction: + """PointerDownAction.""" + + type: str = field(default="pointerDown", init=False) + button: Any | None = None + properties: Any | None = None''', + ], + "extra_methods": [ + ''' def add_file_dialog_handler(self, callback) -> int: + """Subscribe to the input.fileDialogOpened event. + + Args: + callback: Callable invoked with a FileDialogInfo when a file dialog opens. + + Returns: + A handler ID that can be passed to remove_file_dialog_handler. + """ + return self._event_manager.add_event_handler("file_dialog_opened", callback) + + def remove_file_dialog_handler(self, handler_id: int) -> None: + """Unsubscribe a previously registered file dialog event handler. + + Args: + handler_id: The handler ID returned by add_file_dialog_handler. + """ + return self._event_manager.remove_event_handler("file_dialog_opened", handler_id)''', + ], + }, +} + + +# ============================================================================ +# Pre-processing Functions +# ============================================================================ + + +def check_serialize_method(obj: Any) -> Any: + """Check if object has to_bidi_dict() method and use it for serialization.""" + if obj and hasattr(obj, "to_bidi_dict"): + return obj.to_bidi_dict() + return obj + + +# ============================================================================ +# Validation Functions +# ============================================================================ + + +def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + """Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts (ignored for validation) + + Raises: + ValueError: If parameters are invalid + """ + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +# ============================================================================ +# Transformation Functions +# ============================================================================ + + +def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any]: + """Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + """ + if allowed is True: + return { + "type": "allowed", + # Convert pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": ( + str(destination_folder) if destination_folder is not None else None + ), + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +# ============================================================================ +# Dataclass Method Templates +# ============================================================================ + +DATACLASS_METHOD_TEMPLATES: dict[str, dict[str, str]] = { + "ClientWindowInfo": { + "get_client_window": "return self.client_window", + "get_state": "return self.state", + "get_width": "return self.width", + "get_height": "return self.height", + "is_active": "return self.active", + "get_x": "return self.x", + "get_y": "return self.y", + }, + "BrowsingContext": { + "add_event_handler": "_add_event_handler_impl", + "remove_event_handler": "_remove_event_handler_impl", + }, +} + +DATACLASS_METHOD_DOCSTRINGS: dict[str, dict[str, str]] = { + "ClientWindowInfo": { + "get_client_window": "Get the client window ID.", + "get_state": "Get the client window state.", + "get_width": "Get the client window width.", + "get_height": "Get the client window height.", + "is_active": "Check if the client window is active.", + "get_x": "Get the client window X position.", + "get_y": "Get the client window Y position.", + }, + "BrowsingContext": { + "add_event_handler": "Add an event handler for browsing context events.", + "remove_event_handler": "Remove an event handler by callback ID.", + }, +} + +# ============================================================================ +# Event Handler Support for BrowsingContext +# ============================================================================ + + +def _add_event_handler( + self, + event_name: str, + callback: callable, + contexts: list[str] | None = None, +) -> str: + """Add an event handler for a browsing context event. + + Supported events: + - 'context_created' + - 'context_destroyed' + - 'navigation_started' + - 'navigation_committed' + - 'navigation_failed' + - 'dom_content_loaded' + - 'load' + - 'fragment_navigated' + - 'user_prompt_opened' + - 'user_prompt_closed' + - 'download_will_begin' + - 'download_end' + - 'history_updated' + + Args: + self: The module instance this handler is bound to. + event_name: The name of the event to subscribe to + callback: Callback function to invoke when event occurs + contexts: Optional list of context IDs to limit event subscription + + Returns: + A callback ID that can be used to unsubscribe the handler + """ + if not hasattr(self, "_event_handlers"): + self._event_handlers = {} + self._event_callback_id_counter = 0 + + # Generate unique callback ID + self._event_callback_id_counter += 1 + callback_id = f"callback_{self._event_callback_id_counter}" + + # Store the handler + self._event_handlers[callback_id] = { + "event": event_name, + "callback": callback, + "contexts": contexts, + } + + # Subscribe via the driver's event listening mechanism + if hasattr(self._driver, "_subscribe_event"): + self._driver._subscribe_event(event_name, callback, contexts) + + return callback_id + + +def _remove_event_handler( + self, + callback_id: str, +) -> None: + """Remove an event handler by its callback ID. + + Args: + self: The module instance this handler is bound to. + callback_id: The callback ID returned from add_event_handler + """ + if not hasattr(self, "_event_handlers"): + return + + if callback_id in self._event_handlers: + handler_info = self._event_handlers[callback_id] + + # Unsubscribe from the driver + if hasattr(self._driver, "_unsubscribe_event"): + self._driver._unsubscribe_event( + handler_info["event"], + handler_info["callback"], + handler_info["contexts"], + ) + + del self._event_handlers[callback_id] diff --git a/py/private/cdp.py b/py/private/cdp.py new file mode 100644 index 0000000000000..b097762fe50cd --- /dev/null +++ b/py/private/cdp.py @@ -0,0 +1,515 @@ +# The MIT License(MIT) +# +# Copyright(c) 2018 Hyperion Gray +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files(the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# +# This code comes from https://github.com/HyperionGray/trio-chrome-devtools-protocol/tree/master/trio_cdp + +import contextvars +import importlib +import itertools +import json +import logging +import pathlib +from collections import defaultdict +from collections.abc import AsyncGenerator, AsyncIterator, Generator +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass +from typing import Any, TypeVar + +import trio +from trio_websocket import ConnectionClosed as WsConnectionClosed +from trio_websocket import connect_websocket_url + +logger = logging.getLogger("trio_cdp") +T = TypeVar("T") +MAX_WS_MESSAGE_SIZE = 2**24 + +devtools = None +version = None + + +def import_devtools(ver): + """Attempt to load the current latest available devtools into the module cache for use later.""" + global devtools + global version + version = ver + base = "selenium.webdriver.common.devtools.v" + try: + devtools = importlib.import_module(f"{base}{ver}") + return devtools + except ModuleNotFoundError: + # Attempt to parse and load the 'most recent' devtools module. This is likely + # because cdp has been updated but selenium python has not been released yet. + devtools_path = pathlib.Path(__file__).parents[1].joinpath("devtools") + versions = tuple(f.name for f in devtools_path.iterdir() if f.is_dir()) + latest = max(int(x[1:]) for x in versions) + selenium_logger = logging.getLogger(__name__) + selenium_logger.debug("Falling back to loading `devtools`: v%s", latest) + devtools = importlib.import_module(f"{base}{latest}") + return devtools + + +_connection_context: contextvars.ContextVar = contextvars.ContextVar("connection_context") +_session_context: contextvars.ContextVar = contextvars.ContextVar("session_context") + + +def get_connection_context(fn_name): + """Look up the current connection. + + If there is no current connection, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _connection_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a connection context.") + + +def get_session_context(fn_name): + """Look up the current session. + + If there is no current session, raise a ``RuntimeError`` with a + helpful message. + """ + try: + return _session_context.get() + except LookupError: + raise RuntimeError(f"{fn_name}() must be called in a session context.") + + +@contextmanager +def connection_context(connection): + """Context manager installs ``connection`` as the session context for the current Trio task.""" + token = _connection_context.set(connection) + try: + yield + finally: + _connection_context.reset(token) + + +@contextmanager +def session_context(session): + """Context manager installs ``session`` as the session context for the current Trio task.""" + token = _session_context.set(session) + try: + yield + finally: + _session_context.reset(token) + + +def set_global_connection(connection): + """Install ``connection`` in the root context so that it will become the default connection for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _connection_context + _connection_context = contextvars.ContextVar("_connection_context", default=connection) + + +def set_global_session(session): + """Install ``session`` in the root context so that it will become the default session for all tasks. + + This is generally not recommended, except it may be necessary in + certain use cases such as running inside Jupyter notebook. + """ + global _session_context + _session_context = contextvars.ContextVar("_session_context", default=session) + + +class BrowserError(Exception): + """This exception is raised when the browser's response to a command indicates that an error occurred.""" + + def __init__(self, obj): + self.code = obj.get("code") + self.message = obj.get("message") + self.detail = obj.get("data") + + def __str__(self): + return f"BrowserError {self.detail}" + + +class CdpConnectionClosed(WsConnectionClosed): + """Raised when a public method is called on a closed CDP connection.""" + + def __init__(self, reason): + """Constructor. + + Args: + reason: wsproto.frame_protocol.CloseReason + """ + self.reason = reason + + def __repr__(self): + """Return representation.""" + return f"{self.__class__.__name__}<{self.reason}>" + + +class InternalError(Exception): + """This exception is only raised when there is faulty logic in TrioCDP or the integration with PyCDP.""" + + pass + + +@dataclass +class CmEventProxy: + """A proxy object returned by :meth:`CdpBase.wait_for()``. + + After the context manager executes, this proxy object will have a + value set that contains the returned event. + """ + + value: Any = None + + +class CdpBase: + def __init__(self, ws, session_id, target_id): + self.ws = ws + self.session_id = session_id + self.target_id = target_id + self.channels = defaultdict(set) + self.id_iter = itertools.count() + self.inflight_cmd = {} + self.inflight_result = {} + + async def execute(self, cmd: Generator[dict, T, Any]) -> T: + """Execute a command on the server and wait for the result. + + Args: + cmd: any CDP command + + Returns: + a CDP result + """ + cmd_id = next(self.id_iter) + cmd_event = trio.Event() + self.inflight_cmd[cmd_id] = cmd, cmd_event + request = next(cmd) + request["id"] = cmd_id + if self.session_id: + request["sessionId"] = self.session_id + request_str = json.dumps(request) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Sending CDP message: {cmd_id} {cmd_event}: {request_str}") + try: + await self.ws.send_message(request_str) + except WsConnectionClosed as wcc: + raise CdpConnectionClosed(wcc.reason) from None + await cmd_event.wait() + response = self.inflight_result.pop(cmd_id) + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Received CDP message: {response}") + if isinstance(response, Exception): + if logger.isEnabledFor(logging.DEBUG): + logger.debug(f"Exception raised by {cmd_event} message: {type(response).__name__}") + raise response + return response + + def listen(self, *event_types, buffer_size=10): + """Listen for events. + + Returns: + An async iterator that iterates over events matching the indicated types. + """ + sender, receiver = trio.open_memory_channel(buffer_size) + for event_type in event_types: + self.channels[event_type].add(sender) + return receiver + + @asynccontextmanager + async def wait_for(self, event_type: type[T], buffer_size=10) -> AsyncGenerator[CmEventProxy, None]: + """Wait for an event of the given type and return it. + + This is an async context manager, so you should open it inside + an async with block. The block will not exit until the indicated + event is received. + """ + sender: trio.MemorySendChannel + receiver: trio.MemoryReceiveChannel + sender, receiver = trio.open_memory_channel(buffer_size) + self.channels[event_type].add(sender) + proxy = CmEventProxy() + yield proxy + async with receiver: + event = await receiver.receive() + proxy.value = event + + def _handle_data(self, data): + """Handle incoming WebSocket data. + + Args: + data: a JSON dictionary + """ + if "id" in data: + self._handle_cmd_response(data) + else: + self._handle_event(data) + + def _handle_cmd_response(self, data: dict): + """Handle a response to a command. + + This will set an event flag that will return control to the + task that called the command. + + Args: + data: response as a JSON dictionary + """ + cmd_id = data["id"] + try: + cmd, event = self.inflight_cmd.pop(cmd_id) + except KeyError: + logger.warning("Got a message with a command ID that does not exist: %s", data) + return + if "error" in data: + # If the server reported an error, convert it to an exception and do + # not process the response any further. + self.inflight_result[cmd_id] = BrowserError(data["error"]) + else: + # Otherwise, continue the generator to parse the JSON result + # into a CDP object. + try: + _ = cmd.send(data["result"]) + raise InternalError("The command's generator function did not exit when expected!") + except StopIteration as exit: + return_ = exit.value + self.inflight_result[cmd_id] = return_ + event.set() + + def _handle_event(self, data: dict): + """Handle an event. + + Args: + data: event as a JSON dictionary + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + event = devtools.util.parse_json_event(data) + logger.debug("Received event: %s", event) + to_remove = set() + for sender in self.channels[type(event)]: + try: + sender.send_nowait(event) + except trio.WouldBlock: + logger.error('Unable to send event "%r" due to full channel %s', event, sender) + except trio.BrokenResourceError: + to_remove.add(sender) + if to_remove: + self.channels[type(event)] -= to_remove + + +class CdpSession(CdpBase): + """Contains the state for a CDP session. + + Generally you should not instantiate this object yourself; you should call + :meth:`CdpConnection.open_session`. + """ + + def __init__(self, ws, session_id, target_id): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + session_id: devtools.target.SessionID + target_id: devtools.target.TargetID + """ + super().__init__(ws, session_id, target_id) + + self._dom_enable_count = 0 + self._dom_enable_lock = trio.Lock() + self._page_enable_count = 0 + self._page_enable_lock = trio.Lock() + + @asynccontextmanager + async def dom_enable(self): + """Context manager that executes ``dom.enable()`` when it enters and then calls ``dom.disable()``. + + This keeps track of concurrent callers and only disables DOM + events when all callers have exited. + """ + global devtools + async with self._dom_enable_lock: + self._dom_enable_count += 1 + if self._dom_enable_count == 1: + await self.execute(devtools.dom.enable()) + + yield + + async with self._dom_enable_lock: + self._dom_enable_count -= 1 + if self._dom_enable_count == 0: + await self.execute(devtools.dom.disable()) + + @asynccontextmanager + async def page_enable(self): + """Context manager executes ``page.enable()`` when it enters and then calls ``page.disable()`` when it exits. + + This keeps track of concurrent callers and only disables page + events when all callers have exited. + """ + global devtools + async with self._page_enable_lock: + self._page_enable_count += 1 + if self._page_enable_count == 1: + await self.execute(devtools.page.enable()) + + yield + + async with self._page_enable_lock: + self._page_enable_count -= 1 + if self._page_enable_count == 0: + await self.execute(devtools.page.disable()) + + +class CdpConnection(CdpBase, trio.abc.AsyncResource): + """Contains the connection state for a Chrome DevTools Protocol server. + + CDP can multiplex multiple "sessions" over a single connection. This + class corresponds to the "root" session, i.e. the implicitly created + session that has no session ID. This class is responsible for + reading incoming WebSocket messages and forwarding them to the + corresponding session, as well as handling messages targeted at the + root session itself. You should generally call the + :func:`open_cdp()` instead of instantiating this class directly. + """ + + def __init__(self, ws): + """Constructor. + + Args: + ws: trio_websocket.WebSocketConnection + """ + super().__init__(ws, session_id=None, target_id=None) + self.sessions = {} + + async def aclose(self): + """Close the underlying WebSocket connection. + + This will cause the reader task to gracefully exit when it tries + to read the next message from the WebSocket. All of the public + APIs (``execute()``, ``listen()``, etc.) will raise + ``CdpConnectionClosed`` after the CDP connection is closed. It + is safe to call this multiple times. + """ + await self.ws.aclose() + + @asynccontextmanager + async def open_session(self, target_id) -> AsyncIterator[CdpSession]: + """Context manager opens a session and enables the "simple" style of calling CDP APIs. + + For example, inside a session context, you can call ``await + dom.get_document()`` and it will execute on the current session + automatically. + """ + session = await self.connect_session(target_id) + with session_context(session): + yield session + + async def connect_session(self, target_id) -> "CdpSession": + """Returns a new :class:`CdpSession` connected to the specified target.""" + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + session_id = await self.execute(devtools.target.attach_to_target(target_id, True)) + session = CdpSession(self.ws, session_id, target_id) + self.sessions[session_id] = session + return session + + async def _reader_task(self): + """Runs in the background and handles incoming messages. + + Dispatches responses to commands and events to listeners. + """ + global devtools + if devtools is None: + raise RuntimeError("CDP devtools module not loaded. Call import_devtools() first.") + while True: + try: + message = await self.ws.get_message() + except WsConnectionClosed: + # If the WebSocket is closed, we don't want to throw an + # exception from the reader task. Instead we will throw + # exceptions from the public API methods, and we can quietly + # exit the reader task here. + break + try: + data = json.loads(message) + except json.JSONDecodeError: + raise BrowserError({"code": -32700, "message": "Client received invalid JSON", "data": message}) + logger.debug("Received message %r", data) + if "sessionId" in data: + session_id = devtools.target.SessionID(data["sessionId"]) + try: + session = self.sessions[session_id] + except KeyError: + raise BrowserError( + { + "code": -32700, + "message": "Browser sent a message for an invalid session", + "data": f"{session_id!r}", + } + ) + session._handle_data(data) + else: + self._handle_data(data) + + for _, session in self.sessions.items(): + for _, senders in session.channels.items(): + for sender in senders: + sender.close() + + +@asynccontextmanager +async def open_cdp(url) -> AsyncIterator[CdpConnection]: + """Async context manager opens a connection to the browser then closes the connection when the block exits. + + The context manager also sets the connection as the default + connection for the current task, so that commands like ``await + target.get_targets()`` will run on this connection automatically. If + you want to use multiple connections concurrently, it is recommended + to open each on in a separate task. + """ + async with trio.open_nursery() as nursery: + conn = await connect_cdp(nursery, url) + try: + with connection_context(conn): + yield conn + finally: + await conn.aclose() + + +async def connect_cdp(nursery, url) -> CdpConnection: + """Connect to the browser specified by ``url`` and spawn a background task in the specified nursery. + + The ``open_cdp()`` context manager is preferred in most situations. + You should only use this function if you need to specify a custom + nursery. This connection is not automatically closed! You can either + use the connection object as a context manager (``async with + conn:``) or else call ``await conn.aclose()`` on it when you are + done with it. If ``set_context`` is True, then the returned + connection will be installed as the default connection for the + current task. This argument is for unusual use cases, such as + running inside of a notebook. + """ + ws = await connect_websocket_url(nursery, url, max_message_size=MAX_WS_MESSAGE_SIZE) + cdp_conn = CdpConnection(ws) + nursery.start_soon(cdp_conn._reader_task) + return cdp_conn diff --git a/py/private/generate_bidi.bzl b/py/private/generate_bidi.bzl new file mode 100644 index 0000000000000..e072279f85e94 --- /dev/null +++ b/py/private/generate_bidi.bzl @@ -0,0 +1,111 @@ +"""Bazel rule for generating WebDriver BiDi Python modules from CDDL specification.""" + +def _generate_bidi_impl(ctx): + """Implementation of the generate_bidi rule.""" + + cddl_file = ctx.file.cddl_file + manifest_file = ctx.file.enhancements_manifest + generator = ctx.executable.generator + output_dir = ctx.attr.module_name + spec_version = ctx.attr.spec_version + + # The generator creates BiDi modules from the CDDL spec + # Using snake_case naming convention for Python files + module_names = [ + "browser", + "browsing_context", + "common", + "console", + "emulation", + "input", + "log", + "network", + "permissions", + "script", + "session", + "storage", + "webextension", + ] + + # Declare all output files + module_files = [ + ctx.actions.declare_file(output_dir + "/" + name + ".py") + for name in module_names + ] + init_file = ctx.actions.declare_file(output_dir + "/__init__.py") + py_typed = ctx.actions.declare_file(output_dir + "/py.typed") + + gen_outputs = module_files + [init_file, py_typed] + + # Copy static extra_srcs into the output directory + extra_outputs = [] + for src in ctx.files.extra_srcs: + out = ctx.actions.declare_file(output_dir + "/" + src.basename) + ctx.actions.symlink(output = out, target_file = src) + extra_outputs.append(out) + + outputs = gen_outputs + extra_outputs + + # Output directory for the generator + output_base = init_file.dirname + + # Build the command to run the generator + args = [ + cddl_file.path, + output_base, + spec_version, + ] + + # Add enhancement manifest if provided + inputs = [cddl_file] + if manifest_file: + args.extend(["--enhancements-manifest", manifest_file.path]) + inputs.append(manifest_file) + + ctx.actions.run( + inputs = inputs, + outputs = gen_outputs, + executable = generator, + arguments = args, + use_default_shell_env = True, + ) + + return [DefaultInfo(files = depset(outputs))] + + +generate_bidi = rule( + implementation = _generate_bidi_impl, + attrs = { + "cddl_file": attr.label( + allow_single_file = [".cddl"], + mandatory = True, + doc = "CDDL specification file", + ), + "enhancements_manifest": attr.label( + allow_single_file = [".py"], + mandatory = False, + doc = "Enhancement manifest Python file (optional)", + ), + "extra_srcs": attr.label_list( + allow_files = [".py"], + mandatory = False, + default = [], + doc = "Additional static Python files to copy verbatim into the output directory", + ), + "generator": attr.label( + executable = True, + cfg = "exec", + mandatory = True, + doc = "Generator script (e.g., generate_bidi.py)", + ), + "module_name": attr.string( + mandatory = True, + doc = "Name of the module being generated (e.g., 'selenium/webdriver/common/bidi')", + ), + "spec_version": attr.string( + default = "1.0", + doc = "WebDriver BiDi specification version", + ), + }, + doc = "Generates Python WebDriver BiDi modules from CDDL specification", +) diff --git a/py/selenium/common/exceptions.py b/py/selenium/common/exceptions.py index c45f530002a8f..7ec809eb20b18 100644 --- a/py/selenium/common/exceptions.py +++ b/py/selenium/common/exceptions.py @@ -27,7 +27,10 @@ class WebDriverException(Exception): """Base webdriver exception.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: super().__init__() self.msg = msg @@ -73,7 +76,10 @@ class NoSuchElementException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#nosuchelementexception" @@ -111,9 +117,14 @@ class StaleElementReferenceException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#staleelementreferenceexception" + ) super().__init__(with_support, screen, stacktrace) @@ -161,7 +172,10 @@ class ElementNotVisibleException(InvalidElementStateException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotvisibleexception" @@ -172,9 +186,14 @@ class ElementNotInteractableException(InvalidElementStateException): """Thrown when element interactions will hit another element due to paint order.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementnotinteractableexception" + ) super().__init__(with_support, screen, stacktrace) @@ -213,7 +232,10 @@ class InvalidSelectorException(WebDriverException): """ def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#invalidselectorexception" @@ -252,9 +274,14 @@ class ElementClickInterceptedException(WebDriverException): """Thrown when element click fails because another element obscures it.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: - with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" + with_support = ( + f"{msg}; {SUPPORT_MSG} {ERROR_URL}#elementclickinterceptedexception" + ) super().__init__(with_support, screen, stacktrace) @@ -271,7 +298,10 @@ class InvalidSessionIdException(WebDriverException): """Thrown when the given session id is not in the list of active sessions.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#invalidsessionidexception" @@ -282,7 +312,10 @@ class SessionNotCreatedException(WebDriverException): """A new session could not be created.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}#sessionnotcreatedexception" @@ -297,7 +330,10 @@ class NoSuchDriverException(WebDriverException): """Raised when driver is not specified and cannot be located.""" def __init__( - self, msg: Any | None = None, screen: str | None = None, stacktrace: Sequence[str] | None = None + self, + msg: Any | None = None, + screen: str | None = None, + stacktrace: Sequence[str] | None = None, ) -> None: with_support = f"{msg}; {SUPPORT_MSG} {ERROR_URL}/driver_location" diff --git a/py/selenium/webdriver/common/bidi/__init__.py b/py/selenium/webdriver/common/bidi/__init__.py index a5b1e6f85a09e..7be7bd4f73856 100644 --- a/py/selenium/webdriver/common/bidi/__init__.py +++ b/py/selenium/webdriver/common/bidi/__init__.py @@ -1,16 +1,30 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. + +from __future__ import annotations + +from .browser import Browser +from .browsing_context import BrowsingContext +from .emulation import Emulation +from .input import Input +from .log import Log +from .network import Network +from .script import Script +from .session import Session +from .storage import Storage +from .webextension import WebExtension + +__all__ = [ + "Browser", + "BrowsingContext", + "Emulation", + "Input", + "Log", + "Network", + "Script", + "Session", + "Storage", + "WebExtension", +] diff --git a/py/selenium/webdriver/common/bidi/_event_manager.py b/py/selenium/webdriver/common/bidi/_event_manager.py new file mode 100644 index 0000000000000..216a5b8eccb70 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/_event_manager.py @@ -0,0 +1,186 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Shared event management helpers for generated WebDriver BiDi modules. + +``EventConfig``, ``_EventWrapper``, and ``_EventManager`` are emitted +identically into every generated module that exposes events. Rather than +duplicating ~160 lines of code across all of those modules, they are defined +once here and imported by the generated files. +""" + +from __future__ import annotations + +import threading +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +from selenium.webdriver.common.bidi.session import Session + + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() diff --git a/py/selenium/webdriver/common/bidi/browser.py b/py/selenium/webdriver/common/bidi/browser.py index 5b449ae69276a..a4ec770fbb135 100644 --- a/py/selenium/webdriver/common/bidi/browser.py +++ b/py/selenium/webdriver/common/bidi/browser.py @@ -1,280 +1,353 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import os +# WebDriver BiDi module: browser +from __future__ import annotations + from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field + + +def transform_download_params( + allowed: bool | None, + destination_folder: str | None, +) -> dict[str, Any] | None: + """Transform download parameters into download_behavior object. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads (accepts str or + pathlib.Path; will be coerced to str) + + Returns: + Dictionary representing the download_behavior object, or None if allowed is None + """ + if allowed is True: + return { + "type": "allowed", + # Coerce pathlib.Path (or any path-like) to str so the BiDi + # protocol always receives a plain JSON string. + "destinationFolder": str(destination_folder) if destination_folder is not None else None, + } + elif allowed is False: + return {"type": "denied"} + else: # None — reset to browser default (sent as JSON null) + return None + + +def validate_download_behavior( + allowed: bool | None, + destination_folder: str | None, + user_contexts: Any | None = None, +) -> None: + """Validate download behavior parameters. + + Args: + allowed: Whether downloads are allowed + destination_folder: Destination folder for downloads + user_contexts: Optional list of user contexts + + Raises: + ValueError: If parameters are invalid + """ + if allowed is True and not destination_folder: + raise ValueError("destination_folder is required when allowed=True") + if allowed is False and destination_folder: + raise ValueError("destination_folder should not be provided when allowed=False") + + +@dataclass +class ClientWindowInfo: + """ClientWindowInfo.""" + + active: bool | None = None + client_window: Any | None = None + height: Any | None = None + state: Any | None = None + width: Any | None = None + x: Any | None = None + y: Any | None = None + + def get_client_window(self): + """Get the client window ID.""" + return self.client_window -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi.session import UserPromptHandler -from selenium.webdriver.common.proxy import Proxy + def get_state(self): + """Get the client window state.""" + return self.state + def get_width(self): + """Get the client window width.""" + return self.width -class ClientWindowState: - """Represents a window state.""" + def get_height(self): + """Get the client window height.""" + return self.height - FULLSCREEN = "fullscreen" - MAXIMIZED = "maximized" - MINIMIZED = "minimized" - NORMAL = "normal" + def is_active(self): + """Check if the client window is active.""" + return self.active - VALID_STATES = {FULLSCREEN, MAXIMIZED, MINIMIZED, NORMAL} + def get_x(self): + """Get the client window X position.""" + return self.x + def get_y(self): + """Get the client window Y position.""" + return self.y -class ClientWindowInfo: - """Represents a client window information.""" - def __init__( - self, - client_window: str, - state: str, - width: int, - height: int, - x: int, - y: int, - active: bool, - ): - self.client_window = client_window - self.state = state - self.width = width - self.height = height - self.x = x - self.y = y - self.active = active - - def get_state(self) -> str: - """Gets the state of the client window. - - Returns: - str: The state of the client window (one of the ClientWindowState constants). - """ - return self.state - def get_client_window(self) -> str: - """Gets the client window identifier. +@dataclass +class UserContextInfo: + """UserContextInfo.""" - Returns: - str: The client window identifier. - """ - return self.client_window + user_context: Any | None = None - def get_width(self) -> int: - """Gets the width of the client window. - Returns: - int: The width of the client window. - """ - return self.width +@dataclass +class CreateUserContextParameters: + """CreateUserContextParameters.""" - def get_height(self) -> int: - """Gets the height of the client window. + accept_insecure_certs: bool | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None - Returns: - int: The height of the client window. - """ - return self.height - def get_x(self) -> int: - """Gets the x coordinate of the client window. +@dataclass +class GetClientWindowsResult: + """GetClientWindowsResult.""" - Returns: - int: The x coordinate of the client window. - """ - return self.x + client_windows: list[Any] = field(default_factory=list) - def get_y(self) -> int: - """Gets the y coordinate of the client window. - Returns: - int: The y coordinate of the client window. - """ - return self.y +@dataclass +class GetUserContextsResult: + """GetUserContextsResult.""" - def is_active(self) -> bool: - """Checks if the client window is active. + user_contexts: list[Any] = field(default_factory=list) - Returns: - bool: True if the client window is active, False otherwise. - """ - return self.active - @classmethod - def from_dict(cls, data: dict) -> "ClientWindowInfo": - """Creates a ClientWindowInfo instance from a dictionary. +@dataclass +class RemoveUserContextParameters: + """RemoveUserContextParameters.""" - Args: - data: A dictionary containing the client window information. + user_context: Any | None = None - Returns: - ClientWindowInfo: A new instance of ClientWindowInfo. - Raises: - ValueError: If required fields are missing or have invalid types. - """ - try: - client_window = data["clientWindow"] - if not isinstance(client_window, str): - raise ValueError("clientWindow must be a string") - - state = data["state"] - if not isinstance(state, str): - raise ValueError("state must be a string") - if state not in ClientWindowState.VALID_STATES: - raise ValueError(f"Invalid state: {state}. Must be one of {ClientWindowState.VALID_STATES}") - - width = data["width"] - if not isinstance(width, int) or width < 0: - raise ValueError(f"width must be a non-negative integer, got {width}") - - height = data["height"] - if not isinstance(height, int) or height < 0: - raise ValueError(f"height must be a non-negative integer, got {height}") - - x = data["x"] - if not isinstance(x, int): - raise ValueError(f"x must be an integer, got {type(x).__name__}") - - y = data["y"] - if not isinstance(y, int): - raise ValueError(f"y must be an integer, got {type(y).__name__}") - - active = data["active"] - if not isinstance(active, bool): - raise ValueError("active must be a boolean") - - return cls( - client_window=client_window, - state=state, - width=width, - height=height, - x=x, - y=y, - active=active, - ) - except (KeyError, TypeError) as e: - raise ValueError(f"Invalid data format for ClientWindowInfo: {e}") from e +@dataclass +class SetClientWindowStateParameters: + """SetClientWindowStateParameters.""" + client_window: Any | None = None -class Browser: - """BiDi implementation of the browser module.""" - def __init__(self, conn): - self.conn = conn +@dataclass +class ClientWindowRectState: + """ClientWindowRectState.""" - def create_user_context( - self, - accept_insecure_certs: bool | None = None, - proxy: Proxy | None = None, - unhandled_prompt_behavior: UserPromptHandler | None = None, - ) -> str: - """Creates a new user context. + state: str = field(default="normal", init=False) + width: Any | None = None + height: Any | None = None + x: Any | None = None + y: Any | None = None - Args: - accept_insecure_certs: Optional flag to accept insecure TLS certificates. - proxy: Optional proxy configuration for the user context. - unhandled_prompt_behavior: Optional configuration for handling user prompts. - Returns: - str: The ID of the created user context. - """ - params: dict[str, Any] = {} +@dataclass +class SetDownloadBehaviorParameters: + """SetDownloadBehaviorParameters.""" - if accept_insecure_certs is not None: - params["acceptInsecureCerts"] = accept_insecure_certs + download_behavior: Any | None = None + user_contexts: list[Any] = field(default_factory=list) - if proxy is not None: - params["proxy"] = proxy.to_bidi_dict() - if unhandled_prompt_behavior is not None: - params["unhandledPromptBehavior"] = unhandled_prompt_behavior.to_dict() +@dataclass +class DownloadBehaviorAllowed: + """DownloadBehaviorAllowed.""" - result = self.conn.execute(command_builder("browser.createUserContext", params)) - return result["userContext"] + type: str = field(default="allowed", init=False) + destination_folder: str | None = None - def get_user_contexts(self) -> list[str]: - """Gets all user contexts. - Returns: - List[str]: A list of user context IDs. - """ - result = self.conn.execute(command_builder("browser.getUserContexts", {})) - return [context_info["userContext"] for context_info in result["userContexts"]] +@dataclass +class DownloadBehaviorDenied: + """DownloadBehaviorDenied.""" - def remove_user_context(self, user_context_id: str) -> None: - """Removes a user context. + type: str = field(default="denied", init=False) - Args: - user_context_id: The ID of the user context to remove. - Raises: - ValueError: If the user context ID is "default" or does not exist. - """ - if user_context_id == "default": - raise ValueError("Cannot remove the default user context") +class ClientWindowNamedState: + """Named states for a browser client window.""" - params = {"userContext": user_context_id} - self.conn.execute(command_builder("browser.removeUserContext", params)) + FULLSCREEN = "fullscreen" + MAXIMIZED = "maximized" + MINIMIZED = "minimized" + NORMAL = "normal" - def get_client_windows(self) -> list[ClientWindowInfo]: - """Gets all client windows. +class Browser: + """WebDriver BiDi browser module.""" - Returns: - List[ClientWindowInfo]: A list of client window information. - """ - result = self.conn.execute(command_builder("browser.getClientWindows", {})) - return [ClientWindowInfo.from_dict(window) for window in result["clientWindows"]] + def __init__(self, conn) -> None: + self._conn = conn + + def close(self): + """Execute browser.close.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.close", params) + result = self._conn.execute(cmd) + return result + + def create_user_context( + self, + accept_insecure_certs: bool | None = None, + proxy: Any | None = None, + unhandled_prompt_behavior: Any | None = None, + ): + """Execute browser.createUserContext.""" + if proxy and hasattr(proxy, 'to_bidi_dict'): + proxy = proxy.to_bidi_dict() + + if unhandled_prompt_behavior and hasattr(unhandled_prompt_behavior, 'to_bidi_dict'): + unhandled_prompt_behavior = unhandled_prompt_behavior.to_bidi_dict() + + params = { + "acceptInsecureCerts": accept_insecure_certs, + "proxy": proxy, + "unhandledPromptBehavior": unhandled_prompt_behavior, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.createUserContext", params) + result = self._conn.execute(cmd) + if result and "userContext" in result: + extracted = result.get("userContext") + return extracted + return result + + def get_client_windows(self): + """Execute browser.getClientWindows.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.getClientWindows", params) + result = self._conn.execute(cmd) + if result and "clientWindows" in result: + items = result.get("clientWindows", []) + return [ + ClientWindowInfo( + active=item.get("active"), + client_window=item.get("clientWindow"), + height=item.get("height"), + state=item.get("state"), + width=item.get("width"), + x=item.get("x"), + y=item.get("y") + ) + for item in items + if isinstance(item, dict) + ] + return [] + + def get_user_contexts(self): + """Execute browser.getUserContexts.""" + params = { + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.getUserContexts", params) + result = self._conn.execute(cmd) + if result and "userContexts" in result: + items = result.get("userContexts", []) + return [ + item.get("userContext") + for item in items + if isinstance(item, dict) + ] + return [] + + def remove_user_context(self, user_context: Any | None = None): + """Execute browser.removeUserContext.""" + if user_context is None: + raise TypeError("remove_user_context() missing required argument: {{snake_param!r}}") + + params = { + "userContext": user_context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browser.removeUserContext", params) + result = self._conn.execute(cmd) + return result def set_download_behavior( self, - *, allowed: bool | None = None, - destination_folder: str | os.PathLike | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set the download behavior for the browser or specific user contexts. + destination_folder: str | None = None, + user_contexts: list[Any] | None = None, + ): + """Set the download behavior for the browser. Args: - allowed: True to allow downloads, False to deny downloads, or None to - clear download behavior (revert to default). - destination_folder: Required when allowed is True. Specifies the folder - to store downloads in. - user_contexts: Optional list of user context IDs to apply this - behavior to. If omitted, updates the default behavior. + allowed: ``True`` to allow downloads, ``False`` to deny, or ``None`` + to reset to browser default (sends ``null`` to the protocol). + destination_folder: Destination folder for downloads. Required when + ``allowed=True``. Accepts a string or :class:`pathlib.Path`. + user_contexts: Optional list of user context IDs. Raises: - ValueError: If allowed=True and destination_folder is missing, or if - allowed=False and destination_folder is provided. + ValueError: If *allowed* is ``True`` and *destination_folder* is + omitted, or ``False`` and *destination_folder* is provided. """ - params: dict[str, Any] = {} - - if allowed is None: - params["downloadBehavior"] = None - else: - if allowed: - if not destination_folder: - raise ValueError("destination_folder is required when allowed=True.") - params["downloadBehavior"] = { - "type": "allowed", - "destinationFolder": os.fspath(destination_folder), - } - else: - if destination_folder: - raise ValueError("destination_folder should not be provided when allowed=False.") - params["downloadBehavior"] = {"type": "denied"} - + validate_download_behavior( + allowed=allowed, + destination_folder=destination_folder, + user_contexts=user_contexts, + ) + download_behavior = transform_download_params(allowed, destination_folder) + # downloadBehavior is a REQUIRED field in the BiDi spec (can be null but + # must be present). Do NOT use a generic None-filter on it. + params: dict = {"downloadBehavior": download_behavior} if user_contexts is not None: params["userContexts"] = user_contexts + cmd = command_builder("browser.setDownloadBehavior", params) + return self._conn.execute(cmd) + def set_client_window_state( + self, + client_window: Any | None = None, + state: Any | None = None, + ): + """Set the client window state. - self.conn.execute(command_builder("browser.setDownloadBehavior", params)) + Args: + client_window: The client window ID to apply the state to. + state: The window state to set. Can be one of: + - A string: "fullscreen", "maximized", "minimized", "normal" + - A ClientWindowRectState object with width, height, x, y + - A dict representing the state + + Raises: + ValueError: If client_window is not provided or state is invalid. + """ + if client_window is None: + raise ValueError("client_window is required") + if state is None: + raise ValueError("state is required") + + # Serialize ClientWindowRectState if needed + state_param = state + if hasattr(state, '__dataclass_fields__'): + # It's a dataclass, convert to dict + state_param = { + k: v for k, v in state.__dict__.items() + if v is not None + } + + params = { + "clientWindow": client_window, + "state": state_param, + } + cmd = command_builder("browser.setClientWindowState", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/browsing_context.py b/py/selenium/webdriver/common/bidi/browsing_context.py index e8ae150342bda..c5489ce865180 100644 --- a/py/selenium/webdriver/common/bidi/browsing_context.py +++ b/py/selenium/webdriver/common/bidi/browsing_context.py @@ -1,35 +1,22 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: browsingContext +from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field import threading from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -from typing_extensions import Sentinel - -from selenium.webdriver.common.bidi.common import command_builder from selenium.webdriver.common.bidi.session import Session -UNDEFINED = Sentinel("UNDEFINED") - class ReadinessState: - """Represents the stage of document loading at which a navigation command will return.""" + """ReadinessState.""" NONE = "none" INTERACTIVE = "interactive" @@ -37,576 +24,493 @@ class ReadinessState: class UserPromptType: - """Represents the possible user prompt types.""" + """UserPromptType.""" ALERT = "alert" - BEFORE_UNLOAD = "beforeunload" + BEFOREUNLOAD = "beforeunload" CONFIRM = "confirm" PROMPT = "prompt" -class NavigationInfo: - """Provides details of an ongoing navigation.""" +class CreateType: + """CreateType.""" - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - ): - self.context = context - self.navigation = navigation - self.timestamp = timestamp - self.url = url + TAB = "tab" + WINDOW = "window" - @classmethod - def from_json(cls, json: dict) -> "NavigationInfo": - """Creates a NavigationInfo instance from a dictionary. - Args: - json: A dictionary containing the navigation information. +class DownloadCompleteParams: + """DownloadCompleteParams.""" - Returns: - A new instance of NavigationInfo. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") + COMPLETE = "complete" - navigation = json.get("navigation") - if navigation is not None and not isinstance(navigation, str): - raise ValueError("navigation must be a string") - timestamp = json.get("timestamp") - if timestamp is None or not isinstance(timestamp, int) or timestamp < 0: - raise ValueError("timestamp is required and must be a non-negative integer") +@dataclass +class Info: + """Info.""" - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") + children: Any | None = None + client_window: Any | None = None + context: Any | None = None + original_opener: Any | None = None + url: str | None = None + user_context: Any | None = None + parent: Any | None = None - return cls(context, navigation, timestamp, url) +@dataclass +class AccessibilityLocator: + """AccessibilityLocator.""" -class BrowsingContextInfo: - """Represents the properties of a navigable.""" + type: str = field(default="accessibility", init=False) + name: str | None = None + role: str | None = None - def __init__( - self, - context: str, - url: str, - children: list["BrowsingContextInfo"] | None, - client_window: str, - user_context: str, - parent: str | None = None, - original_opener: str | None = None, - ): - self.context = context - self.url = url - self.children = children - self.parent = parent - self.user_context = user_context - self.original_opener = original_opener - self.client_window = client_window - @classmethod - def from_json(cls, json: dict) -> "BrowsingContextInfo": - """Creates a BrowsingContextInfo instance from a dictionary. +@dataclass +class CssLocator: + """CssLocator.""" - Args: - json: A dictionary containing the browsing context information. + type: str = field(default="css", init=False) + value: str | None = None - Returns: - A new instance of BrowsingContextInfo. - """ - children = None - raw_children = json.get("children") - if raw_children is not None: - if not isinstance(raw_children, list): - raise ValueError("children must be a list if provided") - - children = [] - for child in raw_children: - if not isinstance(child, dict): - raise ValueError(f"Each child must be a dictionary, got {type(child)}") - children.append(BrowsingContextInfo.from_json(child)) - - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") - - parent = json.get("parent") - if parent is not None and not isinstance(parent, str): - raise ValueError("parent must be a string if provided") - - user_context = json.get("userContext") - if user_context is None or not isinstance(user_context, str): - raise ValueError("userContext is required and must be a string") - - original_opener = json.get("originalOpener") - if original_opener is not None and not isinstance(original_opener, str): - raise ValueError("originalOpener must be a string if provided") - - client_window = json.get("clientWindow") - if client_window is None or not isinstance(client_window, str): - raise ValueError("clientWindow is required and must be a string") - - return cls( - context=context, - url=url, - children=children, - client_window=client_window, - user_context=user_context, - parent=parent, - original_opener=original_opener, - ) +@dataclass +class ContextLocator: + """ContextLocator.""" -class DownloadWillBeginParams(NavigationInfo): - """Parameters for the downloadWillBegin event.""" + type: str = field(default="context", init=False) + context: Any | None = None - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - suggested_filename: str, - ): - super().__init__(context, navigation, timestamp, url) - self.suggested_filename = suggested_filename - @classmethod - def from_json(cls, json: dict) -> "DownloadWillBeginParams": - nav_info = NavigationInfo.from_json(json) - - suggested_filename = json.get("suggestedFilename") - if suggested_filename is None or not isinstance(suggested_filename, str): - raise ValueError("suggestedFilename is required and must be a string") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - suggested_filename=suggested_filename, - ) +@dataclass +class InnerTextLocator: + """InnerTextLocator.""" + type: str = field(default="innerText", init=False) + value: str | None = None + ignore_case: bool | None = None + match_type: Any | None = None + max_depth: Any | None = None -class UserPromptOpenedParams: - """Parameters for the userPromptOpened event.""" - def __init__( - self, - context: str, - handler: str, - message: str, - type: str, - default_value: str | None = None, - ): - self.context = context - self.handler = handler - self.message = message - self.type = type - self.default_value = default_value +@dataclass +class XPathLocator: + """XPathLocator.""" - @classmethod - def from_json(cls, json: dict) -> "UserPromptOpenedParams": - """Creates a UserPromptOpenedParams instance from a dictionary. + type: str = field(default="xpath", init=False) + value: str | None = None - Args: - json: A dictionary containing the user prompt parameters. - Returns: - A new instance of UserPromptOpenedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - handler = json.get("handler") - if handler is None or not isinstance(handler, str): - raise ValueError("handler is required and must be a string") - - message = json.get("message") - if message is None or not isinstance(message, str): - raise ValueError("message is required and must be a string") - - type_value = json.get("type") - if type_value is None or not isinstance(type_value, str): - raise ValueError("type is required and must be a string") - - default_value = json.get("defaultValue") - if default_value is not None and not isinstance(default_value, str): - raise ValueError("defaultValue must be a string if provided") - - return cls( - context=context, - handler=handler, - message=message, - type=type_value, - default_value=default_value, - ) +@dataclass +class BaseNavigationInfo: + """BaseNavigationInfo.""" + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None -class UserPromptClosedParams: - """Parameters for the userPromptClosed event.""" - def __init__( - self, - context: str, - accepted: bool, - type: str, - user_text: str | None = None, - ): - self.context = context - self.accepted = accepted - self.type = type - self.user_text = user_text +@dataclass +class ActivateParameters: + """ActivateParameters.""" - @classmethod - def from_json(cls, json: dict) -> "UserPromptClosedParams": - """Creates a UserPromptClosedParams instance from a dictionary. + context: Any | None = None - Args: - json: A dictionary containing the user prompt closed parameters. - Returns: - A new instance of UserPromptClosedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - accepted = json.get("accepted") - if accepted is None or not isinstance(accepted, bool): - raise ValueError("accepted is required and must be a boolean") - - type_value = json.get("type") - if type_value is None or not isinstance(type_value, str): - raise ValueError("type is required and must be a string") - - user_text = json.get("userText") - if user_text is not None and not isinstance(user_text, str): - raise ValueError("userText must be a string if provided") - - return cls( - context=context, - accepted=accepted, - type=type_value, - user_text=user_text, - ) +@dataclass +class CaptureScreenshotParameters: + """CaptureScreenshotParameters.""" + context: Any | None = None + format: Any | None = None + clip: Any | None = None -class HistoryUpdatedParams: - """Parameters for the historyUpdated event.""" - def __init__( - self, - context: str, - timestamp: int, - url: str, - ): - self.context = context - self.timestamp = timestamp - self.url = url +@dataclass +class ImageFormat: + """ImageFormat.""" - @classmethod - def from_json(cls, json: dict) -> "HistoryUpdatedParams": - """Creates a HistoryUpdatedParams instance from a dictionary. + type: str | None = None + quality: Any | None = None - Args: - json: A dictionary containing the history updated parameters. - Returns: - A new instance of HistoryUpdatedParams. - """ - context = json.get("context") - if context is None or not isinstance(context, str): - raise ValueError("context is required and must be a string") - - timestamp = json.get("timestamp") - if timestamp is None or not isinstance(timestamp, int) or timestamp < 0: - raise ValueError("timestamp is required and must be a non-negative integer") - - url = json.get("url") - if url is None or not isinstance(url, str): - raise ValueError("url is required and must be a string") - - return cls( - context=context, - timestamp=timestamp, - url=url, - ) +@dataclass +class ElementClipRectangle: + """ElementClipRectangle.""" + type: str = field(default="element", init=False) + element: Any | None = None -class DownloadCanceledParams(NavigationInfo): - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - status: str = "canceled", - ): - super().__init__(context, navigation, timestamp, url) - self.status = status - @classmethod - def from_json(cls, json: dict) -> "DownloadCanceledParams": - nav_info = NavigationInfo.from_json(json) - - status = json.get("status") - if status is None or status != "canceled": - raise ValueError("status is required and must be 'canceled'") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - status=status, - ) +@dataclass +class BoxClipRectangle: + """BoxClipRectangle.""" + type: str = field(default="box", init=False) + x: Any | None = None + y: Any | None = None + width: Any | None = None + height: Any | None = None -class DownloadCompleteParams(NavigationInfo): - def __init__( - self, - context: str, - navigation: str | None, - timestamp: int, - url: str, - status: str = "complete", - filepath: str | None = None, - ): - super().__init__(context, navigation, timestamp, url) - self.status = status - self.filepath = filepath - @classmethod - def from_json(cls, json: dict) -> "DownloadCompleteParams": - nav_info = NavigationInfo.from_json(json) - - status = json.get("status") - if status is None or status != "complete": - raise ValueError("status is required and must be 'complete'") - - filepath = json.get("filepath") - if filepath is not None and not isinstance(filepath, str): - raise ValueError("filepath must be a string if provided") - - return cls( - context=nav_info.context, - navigation=nav_info.navigation, - timestamp=nav_info.timestamp, - url=nav_info.url, - status=status, - filepath=filepath, - ) +@dataclass +class CaptureScreenshotResult: + """CaptureScreenshotResult.""" + data: str | None = None -class DownloadEndParams: - """Parameters for the downloadEnd event.""" - def __init__( - self, - download_params: DownloadCanceledParams | DownloadCompleteParams, - ): - self.download_params = download_params +@dataclass +class CloseParameters: + """CloseParameters.""" - @classmethod - def from_json(cls, json: dict) -> "DownloadEndParams": - status = json.get("status") - if status == "canceled": - return cls(DownloadCanceledParams.from_json(json)) - elif status == "complete": - return cls(DownloadCompleteParams.from_json(json)) - else: - raise ValueError("status must be either 'canceled' or 'complete'") + context: Any | None = None + prompt_unload: bool | None = None -class ContextCreated: - """Event class for browsingContext.contextCreated event.""" +@dataclass +class CreateParameters: + """CreateParameters.""" - event_class = "browsingContext.contextCreated" + type: Any | None = None + reference_context: Any | None = None + background: bool | None = None + user_context: Any | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, BrowsingContextInfo): - return json - return BrowsingContextInfo.from_json(json) +@dataclass +class CreateResult: + """CreateResult.""" -class ContextDestroyed: - """Event class for browsingContext.contextDestroyed event.""" + context: Any | None = None - event_class = "browsingContext.contextDestroyed" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, BrowsingContextInfo): - return json - return BrowsingContextInfo.from_json(json) +@dataclass +class GetTreeParameters: + """GetTreeParameters.""" + max_depth: Any | None = None + root: Any | None = None -class NavigationStarted: - """Event class for browsingContext.navigationStarted event.""" - event_class = "browsingContext.navigationStarted" +@dataclass +class GetTreeResult: + """GetTreeResult.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + contexts: Any | None = None -class NavigationCommitted: - """Event class for browsingContext.navigationCommitted event.""" +@dataclass +class HandleUserPromptParameters: + """HandleUserPromptParameters.""" - event_class = "browsingContext.navigationCommitted" + context: Any | None = None + accept: bool | None = None + user_text: str | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class LocateNodesParameters: + """LocateNodesParameters.""" -class NavigationFailed: - """Event class for browsingContext.navigationFailed event.""" + context: Any | None = None + locator: Any | None = None + serialization_options: Any | None = None + start_nodes: list[Any] = field(default_factory=list) - event_class = "browsingContext.navigationFailed" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class LocateNodesResult: + """LocateNodesResult.""" + nodes: list[Any] = field(default_factory=list) -class NavigationAborted: - """Event class for browsingContext.navigationAborted event.""" - event_class = "browsingContext.navigationAborted" +@dataclass +class NavigateParameters: + """NavigateParameters.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + context: Any | None = None + url: str | None = None + wait: Any | None = None -class DomContentLoaded: - """Event class for browsingContext.domContentLoaded event.""" +@dataclass +class NavigateResult: + """NavigateResult.""" - event_class = "browsingContext.domContentLoaded" + navigation: Any | None = None + url: str | None = None - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class PrintParameters: + """PrintParameters.""" -class Load: - """Event class for browsingContext.load event.""" + context: Any | None = None + background: bool | None = None + margin: Any | None = None + page: Any | None = None + scale: Any | None = None + shrink_to_fit: bool | None = None - event_class = "browsingContext.load" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) +@dataclass +class PrintMarginParameters: + """PrintMarginParameters.""" + bottom: Any | None = None + left: Any | None = None + right: Any | None = None + top: Any | None = None -class FragmentNavigated: - """Event class for browsingContext.fragmentNavigated event.""" - event_class = "browsingContext.fragmentNavigated" +@dataclass +class PrintPageParameters: + """PrintPageParameters.""" - @classmethod - def from_json(cls, json: dict): - if isinstance(json, NavigationInfo): - return json - return NavigationInfo.from_json(json) + height: Any | None = None + width: Any | None = None -class DownloadWillBegin: - """Event class for browsingContext.downloadWillBegin event.""" +@dataclass +class PrintResult: + """PrintResult.""" - event_class = "browsingContext.downloadWillBegin" + data: str | None = None - @classmethod - def from_json(cls, json: dict): - return DownloadWillBeginParams.from_json(json) +@dataclass +class ReloadParameters: + """ReloadParameters.""" -class UserPromptOpened: - """Event class for browsingContext.userPromptOpened event.""" + context: Any | None = None + ignore_cache: bool | None = None + wait: Any | None = None - event_class = "browsingContext.userPromptOpened" - @classmethod - def from_json(cls, json: dict): - return UserPromptOpenedParams.from_json(json) +@dataclass +class SetViewportParameters: + """SetViewportParameters.""" + context: Any | None = None + viewport: Any | None = None + device_pixel_ratio: Any | None = None + user_contexts: list[Any] = field(default_factory=list) -class UserPromptClosed: - """Event class for browsingContext.userPromptClosed event.""" - event_class = "browsingContext.userPromptClosed" +@dataclass +class Viewport: + """Viewport.""" - @classmethod - def from_json(cls, json: dict): - return UserPromptClosedParams.from_json(json) + width: Any | None = None + height: Any | None = None -class HistoryUpdated: - """Event class for browsingContext.historyUpdated event.""" +@dataclass +class TraverseHistoryParameters: + """TraverseHistoryParameters.""" - event_class = "browsingContext.historyUpdated" + context: Any | None = None + delta: Any | None = None - @classmethod - def from_json(cls, json: dict): - return HistoryUpdatedParams.from_json(json) + +@dataclass +class HistoryUpdatedParameters: + """HistoryUpdatedParameters.""" + + context: Any | None = None + timestamp: Any | None = None + url: str | None = None + + +@dataclass +class UserPromptClosedParameters: + """UserPromptClosedParameters.""" + + context: Any | None = None + accepted: bool | None = None + type: Any | None = None + user_text: str | None = None -class DownloadEnd: - """Event class for browsingContext.downloadEnd event.""" +@dataclass +class UserPromptOpenedParameters: + """UserPromptOpenedParameters.""" + + context: Any | None = None + handler: Any | None = None + message: str | None = None + type: Any | None = None + default_value: str | None = None + + +@dataclass +class DownloadWillBeginParams: + """DownloadWillBeginParams.""" + + suggested_filename: str | None = None + +@dataclass +class DownloadCanceledParams: + """DownloadCanceledParams.""" - event_class = "browsingContext.downloadEnd" + status: Any | None = None + +@dataclass +class DownloadParams: + """DownloadParams - fields shared by all download end event variants.""" + + status: str | None = None + context: Any | None = None + navigation: Any | None = None + timestamp: Any | None = None + url: str | None = None + filepath: str | None = None + +@dataclass +class DownloadEndParams: + """DownloadEndParams - params for browsingContext.downloadEnd event.""" + + download_params: "DownloadParams | None" = None @classmethod - def from_json(cls, json: dict): - return DownloadEndParams.from_json(json) + def from_json(cls, params: dict) -> "DownloadEndParams": + """Deserialize from BiDi wire-level params dict.""" + dp = DownloadParams( + status=params.get("status"), + context=params.get("context"), + navigation=params.get("navigation"), + timestamp=params.get("timestamp"), + url=params.get("url"), + filepath=params.get("filepath"), + ) + return cls(download_params=dp) + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "context_created": "browsingContext.contextCreated", + "context_destroyed": "browsingContext.contextDestroyed", + "navigation_started": "browsingContext.navigationStarted", + "fragment_navigated": "browsingContext.fragmentNavigated", + "history_updated": "browsingContext.historyUpdated", + "dom_content_loaded": "browsingContext.domContentLoaded", + "load": "browsingContext.load", + "download_will_begin": "browsingContext.downloadWillBegin", + "download_end": "browsingContext.downloadEnd", + "navigation_aborted": "browsingContext.navigationAborted", + "navigation_committed": "browsingContext.navigationCommitted", + "navigation_failed": "browsingContext.navigationFailed", + "user_prompt_closed": "browsingContext.userPromptClosed", + "user_prompt_opened": "browsingContext.userPromptOpened", +} + +def _deserialize_info_list(items: list) -> list | None: + """Recursively deserialize a list of dicts to Info objects. + + Args: + items: List of dicts from the API response + + Returns: + List of Info objects with properly nested children, or None if empty + """ + if not items or not isinstance(items, list): + return None + + result = [] + for item in items: + if isinstance(item, dict): + # Recursively deserialize children only if the key exists in response + children_list = None + if "children" in item: + children_list = _deserialize_info_list(item.get("children", [])) + info = Info( + children=children_list, + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent"), + ) + result.append(info) + return result if result else None + + @dataclass class EventConfig: + """Configuration for a BiDi event.""" event_key: str bidi_event: str event_class: type +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + class _EventManager: - """Class to manage event subscriptions and callbacks for BrowsingContext.""" + """Manages event subscriptions and callbacks.""" def __init__(self, conn, event_configs: dict[str, EventConfig]): self.conn = conn self.event_configs = event_configs self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} self._available_events = ", ".join(sorted(event_configs.keys())) - # Thread safety lock for subscription operations self._subscription_lock = threading.Lock() + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + def validate_event(self, event: str) -> EventConfig: event_config = self.event_configs.get(event) if not event_config: @@ -614,447 +518,471 @@ def validate_event(self, event: str) -> EventConfig: return event_config def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: - """Subscribe to a BiDi event if not already subscribed. - - Args: - bidi_event: The BiDi event name. - contexts: Optional browsing context IDs to subscribe to. - """ + """Subscribe to a BiDi event if not already subscribed.""" with self._subscription_lock: if bidi_event not in self.subscriptions: session = Session(self.conn) - self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts)) - self.subscriptions[bidi_event] = [] + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } def unsubscribe_from_event(self, bidi_event: str) -> None: - """Unsubscribe from a BiDi event if no more callbacks exist. - - Args: - bidi_event: The BiDi event name. - """ + """Unsubscribe from a BiDi event if no more callbacks exist.""" with self._subscription_lock: - callback_list = self.subscriptions.get(bidi_event) - if callback_list is not None and not callback_list: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: session = Session(self.conn) - self.conn.execute(session.unsubscribe(bidi_event)) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) del self.subscriptions[bidi_event] def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: with self._subscription_lock: - self.subscriptions[bidi_event].append(callback_id) + self.subscriptions[bidi_event]["callbacks"].append(callback_id) def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: with self._subscription_lock: - callback_list = self.subscriptions.get(bidi_event) - if callback_list and callback_id in callback_list: - callback_list.remove(callback_id) + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: event_config = self.validate_event(event) - - callback_id = self.conn.add_callback(event_config.event_class, callback) - - # Subscribe to the event if needed + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) self.subscribe_to_event(event_config.bidi_event, contexts) - - # Track the callback self.add_callback_to_tracking(event_config.bidi_event, callback_id) - return callback_id def remove_event_handler(self, event: str, callback_id: int) -> None: event_config = self.validate_event(event) - - # Remove the callback from the connection - self.conn.remove_callback(event_config.event_class, callback_id) - - # Remove from tracking collections + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) self.remove_callback_from_tracking(event_config.bidi_event, callback_id) - - # Unsubscribe if no more callbacks exist self.unsubscribe_from_event(event_config.bidi_event) def clear_event_handlers(self) -> None: - """Clear all event handlers from the browsing context.""" + """Clear all event handlers.""" with self._subscription_lock: if not self.subscriptions: return - session = Session(self.conn) - - for bidi_event, callback_ids in list(self.subscriptions.items()): - event_class = self._bidi_to_class.get(bidi_event) - if event_class: - # Remove all callbacks for this event - for callback_id in callback_ids: - self.conn.remove_callback(event_class, callback_id) - - self.conn.execute(session.unsubscribe(bidi_event)) - + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) self.subscriptions.clear() -class BrowsingContext: - """BiDi implementation of the browsingContext module.""" - - EVENT_CONFIGS = { - "context_created": EventConfig("context_created", "browsingContext.contextCreated", ContextCreated), - "context_destroyed": EventConfig("context_destroyed", "browsingContext.contextDestroyed", ContextDestroyed), - "dom_content_loaded": EventConfig("dom_content_loaded", "browsingContext.domContentLoaded", DomContentLoaded), - "download_end": EventConfig("download_end", "browsingContext.downloadEnd", DownloadEnd), - "download_will_begin": EventConfig( - "download_will_begin", "browsingContext.downloadWillBegin", DownloadWillBegin - ), - "fragment_navigated": EventConfig("fragment_navigated", "browsingContext.fragmentNavigated", FragmentNavigated), - "history_updated": EventConfig("history_updated", "browsingContext.historyUpdated", HistoryUpdated), - "load": EventConfig("load", "browsingContext.load", Load), - "navigation_aborted": EventConfig("navigation_aborted", "browsingContext.navigationAborted", NavigationAborted), - "navigation_committed": EventConfig( - "navigation_committed", "browsingContext.navigationCommitted", NavigationCommitted - ), - "navigation_failed": EventConfig("navigation_failed", "browsingContext.navigationFailed", NavigationFailed), - "navigation_started": EventConfig("navigation_started", "browsingContext.navigationStarted", NavigationStarted), - "user_prompt_closed": EventConfig("user_prompt_closed", "browsingContext.userPromptClosed", UserPromptClosed), - "user_prompt_opened": EventConfig("user_prompt_opened", "browsingContext.userPromptOpened", UserPromptOpened), - } - - def __init__(self, conn): - self.conn = conn - self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - @classmethod - def get_event_names(cls) -> list[str]: - """Get a list of all available event names. - Returns: - A list of event names that can be used with event handlers. - """ - return list(cls.EVENT_CONFIGS.keys()) +class BrowsingContext: + """WebDriver BiDi browsingContext module.""" - def activate(self, context: str) -> None: - """Activates and focuses the given top-level traversable. + EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - Args: - context: The browsing context ID to activate. + def activate(self, context: Any | None = None): + """Execute browsingContext.activate.""" + if context is None: + raise TypeError("activate() missing required argument: {{snake_param!r}}") - Raises: - Exception: If the browsing context is not a top-level traversable. - """ - params = {"context": context} - self.conn.execute(command_builder("browsingContext.activate", params)) + params = { + "context": context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.activate", params) + result = self._conn.execute(cmd) + return result def capture_screenshot( self, - context: str, - origin: str = "viewport", - format: dict | None = None, - clip: dict | None = None, - ) -> str: - """Captures an image of the given navigable, and returns it as a Base64-encoded string. - - Args: - context: The browsing context ID to capture. - origin: The origin of the screenshot, either "viewport" or "document". - format: The format of the screenshot. - clip: The clip rectangle of the screenshot. - - Returns: - The Base64-encoded screenshot. - """ - params: dict[str, Any] = {"context": context, "origin": origin} - if format is not None: - params["format"] = format - if clip is not None: - params["clip"] = clip - - result = self.conn.execute(command_builder("browsingContext.captureScreenshot", params)) - return result["data"] + context: str | None = None, + format: Any | None = None, + clip: Any | None = None, + origin: str | None = None, + ): + """Execute browsingContext.captureScreenshot.""" + if context is None: + raise TypeError("capture_screenshot() missing required argument: {{snake_param!r}}") - def close(self, context: str, prompt_unload: bool = False) -> None: - """Closes a top-level traversable. + params = { + "context": context, + "format": format, + "clip": clip, + "origin": origin, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.captureScreenshot", params) + result = self._conn.execute(cmd) + if result and "data" in result: + extracted = result.get("data") + return extracted + return result - Args: - context: The browsing context ID to close. - prompt_unload: Whether to prompt to unload. + def close(self, context: Any | None = None, prompt_unload: bool | None = None): + """Execute browsingContext.close.""" + if context is None: + raise TypeError("close() missing required argument: {{snake_param!r}}") - Raises: - Exception: If the browsing context is not a top-level traversable. - """ - params = {"context": context, "promptUnload": prompt_unload} - self.conn.execute(command_builder("browsingContext.close", params)) + params = { + "context": context, + "promptUnload": prompt_unload, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.close", params) + result = self._conn.execute(cmd) + return result def create( self, - type: str, - reference_context: str | None = None, - background: bool = False, - user_context: str | None = None, - ) -> str: - """Creates a new navigable, either in a new tab or in a new window, and returns its navigable id. - - Args: - type: The type of the new navigable, either "tab" or "window". - reference_context: The reference browsing context ID. - background: Whether to create the new navigable in the background. - user_context: The user context ID. - - Returns: - The browsing context ID of the created navigable. - """ - params: dict[str, Any] = {"type": type} - if reference_context is not None: - params["referenceContext"] = reference_context - if background is not None: - params["background"] = background - if user_context is not None: - params["userContext"] = user_context - - result = self.conn.execute(command_builder("browsingContext.create", params)) - return result["context"] - - def get_tree( - self, - max_depth: int | None = None, - root: str | None = None, - ) -> list[BrowsingContextInfo]: - """Get a tree of all descendent navigables including the given parent itself. - - Returns a tree of all descendent navigables including the given parent itself, or all top-level contexts - when no parent is provided. - - Args: - max_depth: The maximum depth of the tree. - root: The root browsing context ID. - - Returns: - A list of browsing context information. - """ - params: dict[str, Any] = {} - if max_depth is not None: - params["maxDepth"] = max_depth - if root is not None: - params["root"] = root - - result = self.conn.execute(command_builder("browsingContext.getTree", params)) - return [BrowsingContextInfo.from_json(context) for context in result["contexts"]] + type: Any | None = None, + reference_context: Any | None = None, + background: bool | None = None, + user_context: Any | None = None, + ): + """Execute browsingContext.create.""" + if type is None: + raise TypeError("create() missing required argument: {{snake_param!r}}") - def handle_user_prompt( - self, - context: str, - accept: bool | None = None, - user_text: str | None = None, - ) -> None: - """Allows closing an open prompt. + params = { + "type": type, + "referenceContext": reference_context, + "background": background, + "userContext": user_context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.create", params) + result = self._conn.execute(cmd) + if result and "context" in result: + extracted = result.get("context") + return extracted + return result - Args: - context: The browsing context ID. - accept: Whether to accept the prompt. - user_text: The text to enter in the prompt. - """ - params: dict[str, Any] = {"context": context} - if accept is not None: - params["accept"] = accept - if user_text is not None: - params["userText"] = user_text + def get_tree(self, max_depth: Any | None = None, root: Any | None = None): + """Execute browsingContext.getTree.""" + params = { + "maxDepth": max_depth, + "root": root, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.getTree", params) + result = self._conn.execute(cmd) + if result and "contexts" in result: + items = result.get("contexts", []) + return [ + Info( + children=_deserialize_info_list(item.get("children", [])), + client_window=item.get("clientWindow"), + context=item.get("context"), + original_opener=item.get("originalOpener"), + url=item.get("url"), + user_context=item.get("userContext"), + parent=item.get("parent") + ) + for item in items + if isinstance(item, dict) + ] + return [] + + def handle_user_prompt(self, context: Any | None = None, accept: bool | None = None, user_text: Any | None = None): + """Execute browsingContext.handleUserPrompt.""" + if context is None: + raise TypeError("handle_user_prompt() missing required argument: {{snake_param!r}}") - self.conn.execute(command_builder("browsingContext.handleUserPrompt", params)) + params = { + "context": context, + "accept": accept, + "userText": user_text, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.handleUserPrompt", params) + result = self._conn.execute(cmd) + return result def locate_nodes( self, - context: str, - locator: dict, + context: str | None = None, + locator: Any | None = None, + serialization_options: Any | None = None, + start_nodes: Any | None = None, max_node_count: int | None = None, - serialization_options: dict | None = None, - start_nodes: list[dict] | None = None, - ) -> list[dict]: - """Returns a list of all nodes matching the specified locator. - - Args: - context: The browsing context ID. - locator: The locator to use. - max_node_count: The maximum number of nodes to return. - serialization_options: The serialization options. - start_nodes: The start nodes. - - Returns: - A list of nodes. - """ - params: dict[str, Any] = {"context": context, "locator": locator} - if max_node_count is not None: - params["maxNodeCount"] = max_node_count - if serialization_options is not None: - params["serializationOptions"] = serialization_options - if start_nodes is not None: - params["startNodes"] = start_nodes - - result = self.conn.execute(command_builder("browsingContext.locateNodes", params)) - return result["nodes"] - - def navigate( - self, - context: str, - url: str, - wait: str | None = None, - ) -> dict: - """Navigates a navigable to the given URL. + ): + """Execute browsingContext.locateNodes.""" + if context is None: + raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") + if locator is None: + raise TypeError("locate_nodes() missing required argument: {{snake_param!r}}") - Args: - context: The browsing context ID. - url: The URL to navigate to. - wait: The readiness state to wait for. + params = { + "context": context, + "locator": locator, + "serializationOptions": serialization_options, + "startNodes": start_nodes, + "maxNodeCount": max_node_count, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.locateNodes", params) + result = self._conn.execute(cmd) + if result and "nodes" in result: + extracted = result.get("nodes") + return extracted + return result - Returns: - A dictionary containing the navigation result. - """ - params = {"context": context, "url": url} - if wait is not None: - params["wait"] = wait + def navigate(self, context: Any | None = None, url: Any | None = None, wait: Any | None = None): + """Execute browsingContext.navigate.""" + if context is None: + raise TypeError("navigate() missing required argument: {{snake_param!r}}") + if url is None: + raise TypeError("navigate() missing required argument: {{snake_param!r}}") - result = self.conn.execute(command_builder("browsingContext.navigate", params)) + params = { + "context": context, + "url": url, + "wait": wait, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.navigate", params) + result = self._conn.execute(cmd) return result def print( self, - context: str, - background: bool = False, - margin: dict | None = None, - orientation: str = "portrait", - page: dict | None = None, - page_ranges: list[int | str] | None = None, - scale: float = 1.0, - shrink_to_fit: bool = True, - ) -> str: - """Create a paginated PDF representation of the document as a Base64-encoded string. - - Args: - context: The browsing context ID. - background: Whether to include the background. - margin: The margin parameters. - orientation: The orientation, either "portrait" or "landscape". - page: The page parameters. - page_ranges: The page ranges. - scale: The scale. - shrink_to_fit: Whether to shrink to fit. + context: Any | None = None, + background: bool | None = None, + margin: Any | None = None, + page: Any | None = None, + scale: Any | None = None, + shrink_to_fit: bool | None = None, + ): + """Execute browsingContext.print.""" + if context is None: + raise TypeError("print() missing required argument: {{snake_param!r}}") - Returns: - The Base64-encoded PDF document. - """ params = { "context": context, "background": background, - "orientation": orientation, + "margin": margin, + "page": page, "scale": scale, "shrinkToFit": shrink_to_fit, } - if margin is not None: - params["margin"] = margin - if page is not None: - params["page"] = page - if page_ranges is not None: - params["pageRanges"] = page_ranges - - result = self.conn.execute(command_builder("browsingContext.print", params)) - return result["data"] - - def reload( - self, - context: str, - ignore_cache: bool | None = None, - wait: str | None = None, - ) -> dict: - """Reloads a navigable. + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.print", params) + result = self._conn.execute(cmd) + if result and "data" in result: + extracted = result.get("data") + return extracted + return result - Args: - context: The browsing context ID. - ignore_cache: Whether to ignore the cache. - wait: The readiness state to wait for. + def reload(self, context: Any | None = None, ignore_cache: bool | None = None, wait: Any | None = None): + """Execute browsingContext.reload.""" + if context is None: + raise TypeError("reload() missing required argument: {{snake_param!r}}") - Returns: - A dictionary containing the navigation result. - """ - params: dict[str, Any] = {"context": context} - if ignore_cache is not None: - params["ignoreCache"] = ignore_cache - if wait is not None: - params["wait"] = wait - - result = self.conn.execute(command_builder("browsingContext.reload", params)) + params = { + "context": context, + "ignoreCache": ignore_cache, + "wait": wait, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.reload", params) + result = self._conn.execute(cmd) return result def set_viewport( self, context: str | None = None, - viewport: dict | None | Sentinel = UNDEFINED, - device_pixel_ratio: float | None | Sentinel = UNDEFINED, - user_contexts: list[str] | None = None, - ) -> None: - """Modifies specific viewport characteristics on the given top-level traversable. - - Args: - context: The browsing context ID. - viewport: The viewport parameters - {"width": , "height": } (`None` resets to default). - device_pixel_ratio: The device pixel ratio (`None` resets to default). - user_contexts: The user context IDs. - - Raises: - Exception: If the browsing context is not a top-level traversable - ValueError: If neither `context` nor `user_contexts` is provided - ValueError: If both `context` and `user_contexts` are provided - """ - if context is not None and user_contexts is not None: - raise ValueError("Cannot specify both context and user_contexts") - - if context is None and user_contexts is None: - raise ValueError("Must specify either context or user_contexts") - - params: dict[str, Any] = {} - if context is not None: - params["context"] = context - elif user_contexts is not None: - params["userContexts"] = user_contexts - if viewport is not UNDEFINED: - params["viewport"] = viewport - if device_pixel_ratio is not UNDEFINED: - params["devicePixelRatio"] = device_pixel_ratio - - self.conn.execute(command_builder("browsingContext.setViewport", params)) - - def traverse_history(self, context: str, delta: int) -> dict: - """Traverses the history of a given navigable by a delta. + viewport: Any | None = None, + user_contexts: Any | None = None, + device_pixel_ratio: Any | None = None, + ): + """Execute browsingContext.setViewport.""" + params = { + "context": context, + "viewport": viewport, + "userContexts": user_contexts, + "devicePixelRatio": device_pixel_ratio, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.setViewport", params) + result = self._conn.execute(cmd) + return result - Args: - context: The browsing context ID. - delta: The delta to traverse by. + def traverse_history(self, context: Any | None = None, delta: Any | None = None): + """Execute browsingContext.traverseHistory.""" + if context is None: + raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") + if delta is None: + raise TypeError("traverse_history() missing required argument: {{snake_param!r}}") - Returns: - A dictionary containing the traverse history result. - """ - params = {"context": context, "delta": delta} - result = self.conn.execute(command_builder("browsingContext.traverseHistory", params)) + params = { + "context": context, + "delta": delta, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("browsingContext.traverseHistory", params) + result = self._conn.execute(cmd) return result + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: - """Add an event handler to the browsing context. + """Add an event handler. Args: event: The event to subscribe to. callback: The callback function to execute on event. - contexts: The browsing context IDs to subscribe to. + contexts: The context IDs to subscribe to (optional). Returns: - Callback id. + The callback ID. """ return self._event_manager.add_event_handler(event, callback, contexts) def remove_event_handler(self, event: str, callback_id: int) -> None: - """Remove an event handler from the browsing context. + """Remove an event handler. Args: event: The event to unsubscribe from. - callback_id: The callback id to remove. + callback_id: The callback ID. """ - self._event_manager.remove_event_handler(event, callback_id) + return self._event_manager.remove_event_handler(event, callback_id) def clear_event_handlers(self) -> None: - """Clear all event handlers from the browsing context.""" - self._event_manager.clear_event_handlers() + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: browsingContext.contextCreated +ContextCreated = globals().get('Info', dict) # Fallback to dict if type not defined + +# Event: browsingContext.contextDestroyed +ContextDestroyed = globals().get('Info', dict) # Fallback to dict if type not defined + +# Event: browsingContext.navigationStarted +NavigationStarted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined + +# Event: browsingContext.fragmentNavigated +FragmentNavigated = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined + +# Event: browsingContext.historyUpdated +HistoryUpdated = globals().get('HistoryUpdatedParameters', dict) # Fallback to dict if type not defined + +# Event: browsingContext.domContentLoaded +DomContentLoaded = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined + +# Event: browsingContext.load +Load = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined + +# Event: browsingContext.downloadWillBegin +DownloadWillBegin = globals().get('DownloadWillBeginParams', dict) # Fallback to dict if type not defined + +# Event: browsingContext.downloadEnd +DownloadEnd = globals().get('DownloadEndParams', dict) # Fallback to dict if type not defined + +# Event: browsingContext.navigationAborted +NavigationAborted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined + +# Event: browsingContext.navigationCommitted +NavigationCommitted = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined + +# Event: browsingContext.navigationFailed +NavigationFailed = globals().get('BaseNavigationInfo', dict) # Fallback to dict if type not defined + +# Event: browsingContext.userPromptClosed +UserPromptClosed = globals().get('UserPromptClosedParameters', dict) # Fallback to dict if type not defined + +# Event: browsingContext.userPromptOpened +UserPromptOpened = globals().get('UserPromptOpenedParameters', dict) # Fallback to dict if type not defined + + +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +BrowsingContext.EVENT_CONFIGS = { + "context_created": EventConfig( + "context_created", + "browsingContext.contextCreated", + _globals.get("ContextCreated", dict) if _globals.get("ContextCreated") else dict, + ), + "context_destroyed": EventConfig( + "context_destroyed", + "browsingContext.contextDestroyed", + _globals.get("ContextDestroyed", dict) if _globals.get("ContextDestroyed") else dict, + ), + "navigation_started": EventConfig( + "navigation_started", + "browsingContext.navigationStarted", + _globals.get("NavigationStarted", dict) if _globals.get("NavigationStarted") else dict, + ), + "fragment_navigated": EventConfig( + "fragment_navigated", + "browsingContext.fragmentNavigated", + _globals.get("FragmentNavigated", dict) if _globals.get("FragmentNavigated") else dict, + ), + "history_updated": EventConfig( + "history_updated", + "browsingContext.historyUpdated", + _globals.get("HistoryUpdated", dict) if _globals.get("HistoryUpdated") else dict, + ), + "dom_content_loaded": EventConfig( + "dom_content_loaded", + "browsingContext.domContentLoaded", + _globals.get("DomContentLoaded", dict) if _globals.get("DomContentLoaded") else dict, + ), + "load": EventConfig("load", "browsingContext.load", _globals.get("Load", dict) if _globals.get("Load") else dict), + "download_will_begin": EventConfig( + "download_will_begin", + "browsingContext.downloadWillBegin", + _globals.get("DownloadWillBegin", dict) if _globals.get("DownloadWillBegin") else dict, + ), + "download_end": EventConfig( + "download_end", + "browsingContext.downloadEnd", + _globals.get("DownloadEnd", dict) if _globals.get("DownloadEnd") else dict, + ), + "navigation_aborted": EventConfig( + "navigation_aborted", + "browsingContext.navigationAborted", + _globals.get("NavigationAborted", dict) if _globals.get("NavigationAborted") else dict, + ), + "navigation_committed": EventConfig( + "navigation_committed", + "browsingContext.navigationCommitted", + _globals.get("NavigationCommitted", dict) if _globals.get("NavigationCommitted") else dict, + ), + "navigation_failed": EventConfig( + "navigation_failed", + "browsingContext.navigationFailed", + _globals.get("NavigationFailed", dict) if _globals.get("NavigationFailed") else dict, + ), + "user_prompt_closed": EventConfig( + "user_prompt_closed", + "browsingContext.userPromptClosed", + _globals.get("UserPromptClosed", dict) if _globals.get("UserPromptClosed") else dict, + ), + "user_prompt_opened": EventConfig( + "user_prompt_opened", + "browsingContext.userPromptOpened", + _globals.get("UserPromptOpened", dict) if _globals.get("UserPromptOpened") else dict, + ), +} diff --git a/py/selenium/webdriver/common/bidi/common.py b/py/selenium/webdriver/common/bidi/common.py index 0f57d07e5f0d4..59e8afd93ab2e 100644 --- a/py/selenium/webdriver/common/bidi/common.py +++ b/py/selenium/webdriver/common/bidi/common.py @@ -15,22 +15,28 @@ # specific language governing permissions and limitations # under the License. +"""Common utilities for BiDi command construction.""" + +from __future__ import annotations + from collections.abc import Generator +from typing import Any -def command_builder(method: str, params: dict | None = None) -> Generator[dict, dict, dict]: - """Build a command iterator to send to the BiDi protocol. +def command_builder( + method: str, params: dict[str, Any] +) -> Generator[dict[str, Any], Any, Any]: + """Build a BiDi command generator. Args: - method: The method to execute. - params: The parameters to pass to the method. Default is None. + method: The BiDi method name (e.g., "session.status", "browser.close") + params: The parameters for the command + + Yields: + A dictionary representing the BiDi command Returns: - The response from the command execution. + The result from the BiDi command execution """ - if params is None: - params = {} - - command = {"method": method, "params": params} - cmd = yield command - return cmd + result = yield {"method": method, "params": params} + return result diff --git a/py/selenium/webdriver/common/bidi/emulation.py b/py/selenium/webdriver/common/bidi/emulation.py index a6acaefe89b83..03347a0a85c04 100644 --- a/py/selenium/webdriver/common/bidi/emulation.py +++ b/py/selenium/webdriver/common/bidi/emulation.py @@ -1,39 +1,33 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: emulation from __future__ import annotations -from enum import Enum -from typing import TYPE_CHECKING, Any, TypeVar +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field -from selenium.webdriver.common.bidi.common import command_builder -if TYPE_CHECKING: - from selenium.webdriver.remote.websocket_connection import WebSocketConnection +class ForcedColorsModeTheme: + """ForcedColorsModeTheme.""" + LIGHT = "light" + DARK = "dark" -class ScreenOrientationNatural(Enum): - """Natural screen orientation.""" + +class ScreenOrientationNatural: + """ScreenOrientationNatural.""" PORTRAIT = "portrait" LANDSCAPE = "landscape" -class ScreenOrientationType(Enum): - """Screen orientation type.""" +class ScreenOrientationType: + """ScreenOrientationType.""" PORTRAIT_PRIMARY = "portrait-primary" PORTRAIT_SECONDARY = "portrait-secondary" @@ -41,484 +35,463 @@ class ScreenOrientationType(Enum): LANDSCAPE_SECONDARY = "landscape-secondary" -E = TypeVar("E", ScreenOrientationNatural, ScreenOrientationType) +@dataclass +class SetForcedColorsModeThemeOverrideParameters: + """SetForcedColorsModeThemeOverrideParameters.""" + theme: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) -def _convert_to_enum(value: E | str, enum_class: type[E]) -> E: - if isinstance(value, enum_class): - return value - assert isinstance(value, str) - try: - return enum_class(value.lower()) - except ValueError: - raise ValueError(f"Invalid orientation: {value}") +@dataclass +class SetGeolocationOverrideParameters: + """SetGeolocationOverrideParameters.""" -class ScreenOrientation: - """Represents screen orientation configuration.""" + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - def __init__( - self, - natural: ScreenOrientationNatural | str, - type: ScreenOrientationType | str, - ): - """Initialize ScreenOrientation. - Args: - natural: Natural screen orientation ("portrait" or "landscape"). - type: Screen orientation type ("portrait-primary", "portrait-secondary", - "landscape-primary", or "landscape-secondary"). +@dataclass +class GeolocationCoordinates: + """GeolocationCoordinates.""" - Raises: - ValueError: If natural or type values are invalid. - """ - # handle string values - self.natural = _convert_to_enum(natural, ScreenOrientationNatural) - self.type = _convert_to_enum(type, ScreenOrientationType) - - def to_dict(self) -> dict[str, str]: - return { - "natural": self.natural.value, - "type": self.type.value, - } + latitude: Any | None = None + longitude: Any | None = None + accuracy: Any | None = None + altitude: Any | None = None + altitude_accuracy: Any | None = None + heading: Any | None = None + speed: Any | None = None -class GeolocationCoordinates: - """Represents geolocation coordinates.""" +@dataclass +class GeolocationPositionError: + """GeolocationPositionError.""" - def __init__( - self, - latitude: float, - longitude: float, - accuracy: float = 1.0, - altitude: float | None = None, - altitude_accuracy: float | None = None, - heading: float | None = None, - speed: float | None = None, - ): - """Initialize GeolocationCoordinates. + type: str = field(default="positionUnavailable", init=False) - Args: - latitude: Latitude coordinate (-90.0 to 90.0). - longitude: Longitude coordinate (-180.0 to 180.0). - accuracy: Accuracy in meters (>= 0.0), defaults to 1.0. - altitude: Altitude in meters or None, defaults to None. - altitude_accuracy: Altitude accuracy in meters (>= 0.0) or None, defaults to None. - heading: Heading in degrees (0.0 to 360.0) or None, defaults to None. - speed: Speed in meters per second (>= 0.0) or None, defaults to None. - - Raises: - ValueError: If coordinates are out of valid range or if altitude_accuracy is provided without altitude. - """ - self.latitude = latitude - self.longitude = longitude - self.accuracy = accuracy - self.altitude = altitude - self.altitude_accuracy = altitude_accuracy - self.heading = heading - self.speed = speed - - @property - def latitude(self) -> float: - return self._latitude - - @latitude.setter - def latitude(self, value: float) -> None: - if not (-90.0 <= value <= 90.0): - raise ValueError("latitude must be between -90.0 and 90.0") - self._latitude = value - - @property - def longitude(self) -> float: - return self._longitude - - @longitude.setter - def longitude(self, value: float) -> None: - if not (-180.0 <= value <= 180.0): - raise ValueError("longitude must be between -180.0 and 180.0") - self._longitude = value - - @property - def accuracy(self) -> float: - return self._accuracy - - @accuracy.setter - def accuracy(self, value: float) -> None: - if value < 0.0: - raise ValueError("accuracy must be >= 0.0") - self._accuracy = value - - @property - def altitude(self) -> float | None: - return self._altitude - - @altitude.setter - def altitude(self, value: float | None) -> None: - self._altitude = value - - @property - def altitude_accuracy(self) -> float | None: - return self._altitude_accuracy - - @altitude_accuracy.setter - def altitude_accuracy(self, value: float | None) -> None: - if value is not None and self.altitude is None: - raise ValueError("altitude_accuracy cannot be set without altitude") - if value is not None and value < 0.0: - raise ValueError("altitude_accuracy must be >= 0.0") - self._altitude_accuracy = value - - @property - def heading(self) -> float | None: - return self._heading - - @heading.setter - def heading(self, value: float | None) -> None: - if value is not None and not (0.0 <= value < 360.0): - raise ValueError("heading must be between 0.0 and 360.0") - self._heading = value - - @property - def speed(self) -> float | None: - return self._speed - - @speed.setter - def speed(self, value: float | None) -> None: - if value is not None and value < 0.0: - raise ValueError("speed must be >= 0.0") - self._speed = value - - def to_dict(self) -> dict[str, float | None]: - result: dict[str, float | None] = { - "latitude": self.latitude, - "longitude": self.longitude, - "accuracy": self.accuracy, - } - if self.altitude is not None: - result["altitude"] = self.altitude +@dataclass +class SetLocaleOverrideParameters: + """SetLocaleOverrideParameters.""" - if self.altitude_accuracy is not None: - result["altitudeAccuracy"] = self.altitude_accuracy + locale: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - if self.heading is not None: - result["heading"] = self.heading - if self.speed is not None: - result["speed"] = self.speed +@dataclass +class setNetworkConditionsParameters: + """setNetworkConditionsParameters.""" - return result + network_conditions: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) -class GeolocationPositionError: - """Represents a geolocation position error.""" +@dataclass +class NetworkConditionsOffline: + """NetworkConditionsOffline.""" - TYPE_POSITION_UNAVAILABLE = "positionUnavailable" + type: str = field(default="offline", init=False) - def __init__(self, type: str = TYPE_POSITION_UNAVAILABLE): - if type != self.TYPE_POSITION_UNAVAILABLE: - raise ValueError(f'type must be "{self.TYPE_POSITION_UNAVAILABLE}"') - self.type = type - def to_dict(self) -> dict[str, str]: - return {"type": self.type} +@dataclass +class ScreenArea: + """ScreenArea.""" + width: Any | None = None + height: Any | None = None -class Emulation: - """BiDi implementation of the emulation module.""" - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn +@dataclass +class SetScreenSettingsOverrideParameters: + """SetScreenSettingsOverrideParameters.""" - def set_geolocation_override( - self, - coordinates: GeolocationCoordinates | None = None, - error: GeolocationPositionError | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set geolocation override for the given contexts or user contexts. + screen_area: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - Args: - coordinates: Geolocation coordinates to emulate, or None. - error: Geolocation error to emulate, or None. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both coordinates and error are provided, or if both contexts - and user_contexts are provided, or if neither contexts nor - user_contexts are provided. - """ - if coordinates is not None and error is not None: - raise ValueError("Cannot specify both coordinates and error") - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") +@dataclass +class ScreenOrientation: + """ScreenOrientation.""" - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + natural: Any | None = None + type: Any | None = None - params: dict[str, Any] = {} - if coordinates is not None: - params["coordinates"] = coordinates.to_dict() - elif error is not None: - params["error"] = error.to_dict() +@dataclass +class SetScreenOrientationOverrideParameters: + """SetScreenOrientationOverrideParameters.""" - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts + screen_orientation: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - self.conn.execute(command_builder("emulation.setGeolocationOverride", params)) - def set_timezone_override( - self, - timezone: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set timezone override for the given contexts or user contexts. +@dataclass +class SetUserAgentOverrideParameters: + """SetUserAgentOverrideParameters.""" - Args: - timezone: Timezone identifier (IANA timezone name or offset string like '+01:00'), - or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") + user_agent: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") - params: dict[str, Any] = {"timezone": timezone} +@dataclass +class SetViewportMetaOverrideParameters: + """SetViewportMetaOverrideParameters.""" - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts + viewport_meta: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - self.conn.execute(command_builder("emulation.setTimezoneOverride", params)) - def set_locale_override( - self, - locale: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set locale override for the given contexts or user contexts. +@dataclass +class SetScriptingEnabledParameters: + """SetScriptingEnabledParameters.""" - Args: - locale: Locale string as per BCP 47, or None to clear override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + enabled: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided, or if locale is invalid. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") +@dataclass +class SetScrollbarTypeOverrideParameters: + """SetScrollbarTypeOverrideParameters.""" - params: dict[str, Any] = {"locale": locale} + scrollbar_type: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts - self.conn.execute(command_builder("emulation.setLocaleOverride", params)) +@dataclass +class SetTimezoneOverrideParameters: + """SetTimezoneOverrideParameters.""" - def set_scripting_enabled( - self, - enabled: bool | None = False, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set scripting enabled override for the given contexts or user contexts. + timezone: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - Args: - enabled: False to disable scripting, None to clear the override. - Note: Only emulation of disabled JavaScript is supported. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided, or if enabled is True. - """ - if enabled: - raise ValueError("Only emulation of disabled JavaScript is supported (enabled must be False or None)") - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") +@dataclass +class SetTouchOverrideParameters: + """SetTouchOverrideParameters.""" - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) - params: dict[str, Any] = {"enabled": enabled} - if contexts is not None: - params["contexts"] = contexts - elif user_contexts is not None: - params["userContexts"] = user_contexts +class Emulation: + """WebDriver BiDi emulation module.""" - self.conn.execute(command_builder("emulation.setScriptingEnabled", params)) + def __init__(self, conn) -> None: + self._conn = conn - def set_screen_orientation_override( + def set_forced_colors_mode_theme_override( self, - screen_orientation: ScreenOrientation | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set screen orientation override for the given contexts or user contexts. + theme: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setForcedColorsModeThemeOverride.""" + if theme is None: + raise TypeError("set_forced_colors_mode_theme_override() missing required argument: {{snake_param!r}}") + + params = { + "theme": theme, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setForcedColorsModeThemeOverride", params) + result = self._conn.execute(cmd) + return result - Args: - screen_orientation: ScreenOrientation object to emulate, or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + def set_locale_override( + self, + locale: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setLocaleOverride.""" + if locale is None: + raise TypeError("set_locale_override() missing required argument: {{snake_param!r}}") + + params = { + "locale": locale, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setLocaleOverride", params) + result = self._conn.execute(cmd) + return result - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. - """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and userContexts") + def set_screen_settings_override( + self, + screen_area: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScreenSettingsOverride.""" + if screen_area is None: + raise TypeError("set_screen_settings_override() missing required argument: {{snake_param!r}}") + + params = { + "screenArea": screen_area, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScreenSettingsOverride", params) + result = self._conn.execute(cmd) + return result - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or userContexts") + def set_viewport_meta_override( + self, + viewport_meta: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setViewportMetaOverride.""" + if viewport_meta is None: + raise TypeError("set_viewport_meta_override() missing required argument: {{snake_param!r}}") + + params = { + "viewportMeta": viewport_meta, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setViewportMetaOverride", params) + result = self._conn.execute(cmd) + return result - params: dict[str, Any] = { - "screenOrientation": screen_orientation.to_dict() if screen_orientation is not None else None + def set_scrollbar_type_override( + self, + scrollbar_type: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScrollbarTypeOverride.""" + if scrollbar_type is None: + raise TypeError("set_scrollbar_type_override() missing required argument: {{snake_param!r}}") + + params = { + "scrollbarType": scrollbar_type, + "contexts": contexts, + "userContexts": user_contexts, } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setScrollbarTypeOverride", params) + result = self._conn.execute(cmd) + return result + + def set_touch_override(self, contexts: list[Any] | None = None, user_contexts: list[Any] | None = None): + """Execute emulation.setTouchOverride.""" + params = { + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("emulation.setTouchOverride", params) + result = self._conn.execute(cmd) + return result + def set_geolocation_override( + self, + coordinates=None, + error=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setGeolocationOverride. + + Sets or clears the geolocation override for specified browsing or user contexts. + + Args: + coordinates: A GeolocationCoordinates instance (or dict) to override the + position, or ``None`` to clear a previously-set override. + error: A GeolocationPositionError instance (or dict) to simulate a + position-unavailable error. Mutually exclusive with *coordinates*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params: dict[str, Any] = {} + if coordinates is not None: + if isinstance(coordinates, dict): + coords_dict = coordinates + else: + coords_dict = {} + if coordinates.latitude is not None: + coords_dict["latitude"] = coordinates.latitude + if coordinates.longitude is not None: + coords_dict["longitude"] = coordinates.longitude + if coordinates.accuracy is not None: + coords_dict["accuracy"] = coordinates.accuracy + if coordinates.altitude is not None: + coords_dict["altitude"] = coordinates.altitude + if coordinates.altitude_accuracy is not None: + coords_dict["altitudeAccuracy"] = coordinates.altitude_accuracy + if coordinates.heading is not None: + coords_dict["heading"] = coordinates.heading + if coordinates.speed is not None: + coords_dict["speed"] = coordinates.speed + params["coordinates"] = coords_dict + if error is not None: + if isinstance(error, dict): + params["error"] = error + else: + params["error"] = { + "type": error.type if error.type is not None else "positionUnavailable" + } if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts + cmd = command_builder("emulation.setGeolocationOverride", params) + result = self._conn.execute(cmd) + return result + def set_timezone_override( + self, + timezone=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setTimezoneOverride. - self.conn.execute(command_builder("emulation.setScreenOrientationOverride", params)) + Sets or clears the timezone override for specified browsing or user contexts. + Pass ``timezone=None`` (or omit it) to clear a previously-set override. - def set_user_agent_override( + Args: + timezone: IANA timezone string (e.g. ``"America/New_York"``) or ``None`` + to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ + params: dict[str, Any] = {"timezone": timezone} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setTimezoneOverride", params) + return self._conn.execute(cmd) + def set_scripting_enabled( self, - user_agent: str | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set user agent override for the given contexts or user contexts. + enabled=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScriptingEnabled. - Args: - user_agent: User agent string to emulate, or None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. + Enables or disables scripting for specified browsing or user contexts. + Pass ``enabled=None`` to restore the default behaviour. - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. + Args: + enabled: ``True`` to enable scripting, ``False`` to disable it, or + ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") + params: dict[str, Any] = {"enabled": enabled} + if contexts is not None: + params["contexts"] = contexts + if user_contexts is not None: + params["userContexts"] = user_contexts + cmd = command_builder("emulation.setScriptingEnabled", params) + return self._conn.execute(cmd) + def set_user_agent_override( + self, + user_agent=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setUserAgentOverride. - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") + Overrides the User-Agent string for specified browsing or user contexts. + Pass ``user_agent=None`` to clear a previously-set override. + Args: + user_agent: Custom User-Agent string, or ``None`` to clear the override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. + """ params: dict[str, Any] = {"userAgent": user_agent} - if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("emulation.setUserAgentOverride", params)) - - def set_network_conditions( + cmd = command_builder("emulation.setUserAgentOverride", params) + return self._conn.execute(cmd) + def set_screen_orientation_override( self, - offline: bool = False, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set network conditions for the given contexts or user contexts. + screen_orientation=None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setScreenOrientationOverride. - Args: - offline: True to emulate offline network conditions, False to clear the override. - contexts: List of browsing context IDs to apply the conditions to. - user_contexts: List of user context IDs to apply the conditions to. + Sets or clears the screen orientation override for specified browsing or + user contexts. - Raises: - ValueError: If both contexts and user_contexts are provided, or if neither - contexts nor user_contexts are provided. + Args: + screen_orientation: A :class:`ScreenOrientation` instance (or dict with + ``natural`` and ``type`` keys) to lock the orientation, or ``None`` + to clear a previously-set override. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. """ - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") - - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") - - params: dict[str, Any] = {} - - if offline: - params["networkConditions"] = {"type": "offline"} + if screen_orientation is None: + so_value = None + elif isinstance(screen_orientation, dict): + so_value = screen_orientation else: - # if offline is False or None, then clear the override - params["networkConditions"] = None - + natural = screen_orientation.natural + orientation_type = screen_orientation.type + so_value = { + "natural": natural.lower() if isinstance(natural, str) else natural, + "type": orientation_type.lower() if isinstance(orientation_type, str) else orientation_type, + } + params: dict[str, Any] = {"screenOrientation": so_value} if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("emulation.setNetworkConditions", params)) - - def set_screen_settings_override( + cmd = command_builder("emulation.setScreenOrientationOverride", params) + return self._conn.execute(cmd) + def set_network_conditions( self, - width: int | None = None, - height: int | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - ) -> None: - """Set screen settings override for the given contexts or user contexts. + network_conditions=None, + offline: bool | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute emulation.setNetworkConditions. + + Sets or clears network condition emulation for specified browsing or user + contexts. Args: - width: Screen width in pixels (>= 0). None to clear the override. - height: Screen height in pixels (>= 0). None to clear the override. - contexts: List of browsing context IDs to apply the override to. - user_contexts: List of user context IDs to apply the override to. - - Raises: - ValueError: If only one of width/height is provided, or if both contexts - and user_contexts are provided, or if neither is provided. + network_conditions: A dict with the raw ``networkConditions`` value + (e.g. ``{"type": "offline"}``), or ``None`` to clear the override. + Mutually exclusive with *offline*. + offline: Convenience bool — ``True`` sets offline conditions, + ``False`` clears them (sends ``null``). When provided, this takes + precedence over *network_conditions*. + contexts: List of browsing context IDs to target. + user_contexts: List of user context IDs to target. """ - if (width is None) != (height is None): - raise ValueError("Must provide both width and height, or neither to clear the override") - - if contexts is not None and user_contexts is not None: - raise ValueError("Cannot specify both contexts and user_contexts") - - if contexts is None and user_contexts is None: - raise ValueError("Must specify either contexts or user_contexts") - - screen_area = None - if width is not None and height is not None: - if not isinstance(width, int) or not isinstance(height, int): - raise ValueError("width and height must be integers") - if width < 0 or height < 0: - raise ValueError("width and height must be >= 0") - screen_area = {"width": width, "height": height} - - params: dict[str, Any] = {"screenArea": screen_area} - + if offline is not None: + nc_value = {"type": "offline"} if offline else None + else: + nc_value = network_conditions + params: dict[str, Any] = {"networkConditions": nc_value} if contexts is not None: params["contexts"] = contexts - elif user_contexts is not None: + if user_contexts is not None: params["userContexts"] = user_contexts - - self.conn.execute(command_builder("emulation.setScreenSettingsOverride", params)) + cmd = command_builder("emulation.setNetworkConditions", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/input.py b/py/selenium/webdriver/common/bidi/input.py index 270ececaf41a1..44fd3c82c3407 100644 --- a/py/selenium/webdriver/common/bidi/input.py +++ b/py/selenium/webdriver/common/bidi/input.py @@ -1,40 +1,30 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import math -from dataclasses import dataclass, field -from typing import Any +# WebDriver BiDi module: input +from __future__ import annotations -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field +import threading +from collections.abc import Callable from selenium.webdriver.common.bidi.session import Session class PointerType: - """Represents the possible pointer types.""" + """PointerType.""" MOUSE = "mouse" PEN = "pen" TOUCH = "touch" - VALID_TYPES = {MOUSE, PEN, TOUCH} - class Origin: - """Represents the possible origin types.""" + """Origin.""" VIEWPORT = "viewport" POINTER = "pointer" @@ -42,421 +32,444 @@ class Origin: @dataclass class ElementOrigin: - """Represents an element origin for input actions.""" + """ElementOrigin.""" - type: str - element: dict - - def __init__(self, element_reference: dict): - self.type = "element" - self.element = element_reference - - def to_dict(self) -> dict: - """Convert the ElementOrigin to a dictionary.""" - return {"type": self.type, "element": self.element} + type: str = field(default="element", init=False) + element: Any | None = None @dataclass -class PointerParameters: - """Represents pointer parameters for pointer actions.""" +class PerformActionsParameters: + """PerformActionsParameters.""" - pointer_type: str = PointerType.MOUSE - - def __post_init__(self): - if self.pointer_type not in PointerType.VALID_TYPES: - raise ValueError(f"Invalid pointer type: {self.pointer_type}. Must be one of {PointerType.VALID_TYPES}") - - def to_dict(self) -> dict: - """Convert the PointerParameters to a dictionary.""" - return {"pointerType": self.pointer_type} + context: Any | None = None + actions: list[Any] = field(default_factory=list) @dataclass -class PointerCommonProperties: - """Common properties for pointer actions.""" - - width: int = 1 - height: int = 1 - pressure: float = 0.0 - tangential_pressure: float = 0.0 - twist: int = 0 - altitude_angle: float = 0.0 - azimuth_angle: float = 0.0 - - def __post_init__(self): - if self.width < 1: - raise ValueError("width must be at least 1") - if self.height < 1: - raise ValueError("height must be at least 1") - if not (0.0 <= self.pressure <= 1.0): - raise ValueError("pressure must be between 0.0 and 1.0") - if not (0.0 <= self.tangential_pressure <= 1.0): - raise ValueError("tangential_pressure must be between 0.0 and 1.0") - if not (0 <= self.twist <= 359): - raise ValueError("twist must be between 0 and 359") - if not (0.0 <= self.altitude_angle <= math.pi / 2): - raise ValueError("altitude_angle must be between 0.0 and π/2") - if not (0.0 <= self.azimuth_angle <= 2 * math.pi): - raise ValueError("azimuth_angle must be between 0.0 and 2π") - - def to_dict(self) -> dict: - """Convert the PointerCommonProperties to a dictionary.""" - result: dict[str, Any] = {} - if self.width != 1: - result["width"] = self.width - if self.height != 1: - result["height"] = self.height - if self.pressure != 0.0: - result["pressure"] = self.pressure - if self.tangential_pressure != 0.0: - result["tangentialPressure"] = self.tangential_pressure - if self.twist != 0: - result["twist"] = self.twist - if self.altitude_angle != 0.0: - result["altitudeAngle"] = self.altitude_angle - if self.azimuth_angle != 0.0: - result["azimuthAngle"] = self.azimuth_angle - return result - +class NoneSourceActions: + """NoneSourceActions.""" -# Action classes -@dataclass -class PauseAction: - """Represents a pause action.""" + type: str = field(default="none", init=False) + id: str | None = None + actions: list[Any] = field(default_factory=list) - duration: int | None = None - @property - def type(self) -> str: - return "pause" +@dataclass +class KeySourceActions: + """KeySourceActions.""" - def to_dict(self) -> dict: - """Convert the PauseAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type} - if self.duration is not None: - result["duration"] = self.duration - return result + type: str = field(default="key", init=False) + id: str | None = None + actions: list[Any] = field(default_factory=list) @dataclass -class KeyDownAction: - """Represents a key down action.""" - - value: str = "" - - @property - def type(self) -> str: - return "keyDown" +class PointerSourceActions: + """PointerSourceActions.""" - def to_dict(self) -> dict: - """Convert the KeyDownAction to a dictionary.""" - return {"type": self.type, "value": self.value} + type: str = field(default="pointer", init=False) + id: str | None = None + parameters: Any | None = None + actions: list[Any] = field(default_factory=list) @dataclass -class KeyUpAction: - """Represents a key up action.""" - - value: str = "" - - @property - def type(self) -> str: - return "keyUp" +class PointerParameters: + """PointerParameters.""" - def to_dict(self) -> dict: - """Convert the KeyUpAction to a dictionary.""" - return {"type": self.type, "value": self.value} + pointer_type: Any | None = None @dataclass -class PointerDownAction: - """Represents a pointer down action.""" +class WheelSourceActions: + """WheelSourceActions.""" - button: int = 0 - properties: PointerCommonProperties | None = None + type: str = field(default="wheel", init=False) + id: str | None = None + actions: list[Any] = field(default_factory=list) - @property - def type(self) -> str: - return "pointerDown" - def to_dict(self) -> dict: - """Convert the PointerDownAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type, "button": self.button} - if self.properties: - result.update(self.properties.to_dict()) - return result +@dataclass +class PauseAction: + """PauseAction.""" + + type: str = field(default="pause", init=False) + duration: Any | None = None @dataclass -class PointerUpAction: - """Represents a pointer up action.""" +class KeyDownAction: + """KeyDownAction.""" - button: int = 0 + type: str = field(default="keyDown", init=False) + value: str | None = None - @property - def type(self) -> str: - return "pointerUp" - def to_dict(self) -> dict: - """Convert the PointerUpAction to a dictionary.""" - return {"type": self.type, "button": self.button} +@dataclass +class KeyUpAction: + """KeyUpAction.""" + + type: str = field(default="keyUp", init=False) + value: str | None = None @dataclass -class PointerMoveAction: - """Represents a pointer move action.""" - - x: float = 0 - y: float = 0 - duration: int | None = None - origin: str | ElementOrigin | None = None - properties: PointerCommonProperties | None = None - - @property - def type(self) -> str: - return "pointerMove" - - def to_dict(self) -> dict: - """Convert the PointerMoveAction to a dictionary.""" - result: dict[str, Any] = {"type": self.type, "x": self.x, "y": self.y} - if self.duration is not None: - result["duration"] = self.duration - if self.origin is not None: - if isinstance(self.origin, ElementOrigin): - result["origin"] = self.origin.to_dict() - else: - result["origin"] = self.origin - if self.properties: - result.update(self.properties.to_dict()) - return result +class PointerUpAction: + """PointerUpAction.""" + + type: str = field(default="pointerUp", init=False) + button: Any | None = None @dataclass class WheelScrollAction: - """Represents a wheel scroll action.""" - - x: int = 0 - y: int = 0 - delta_x: int = 0 - delta_y: int = 0 - duration: int | None = None - origin: str | ElementOrigin | None = Origin.VIEWPORT - - @property - def type(self) -> str: - return "scroll" - - def to_dict(self) -> dict: - """Convert the WheelScrollAction to a dictionary.""" - result: dict[str, Any] = { - "type": self.type, - "x": self.x, - "y": self.y, - "deltaX": self.delta_x, - "deltaY": self.delta_y, - } - if self.duration is not None: - result["duration"] = self.duration - if self.origin is not None: - if isinstance(self.origin, ElementOrigin): - result["origin"] = self.origin.to_dict() - else: - result["origin"] = self.origin - return result - + """WheelScrollAction.""" -# Source Actions -@dataclass -class NoneSourceActions: - """Represents a sequence of none actions.""" + type: str = field(default="scroll", init=False) + x: Any | None = None + y: Any | None = None + delta_x: Any | None = None + delta_y: Any | None = None + duration: Any | None = None + origin: Any | None = None - id: str = "" - actions: list[PauseAction] = field(default_factory=list) - @property - def type(self) -> str: - return "none" +@dataclass +class PointerCommonProperties: + """PointerCommonProperties.""" - def to_dict(self) -> dict: - """Convert the NoneSourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + width: Any | None = None + height: Any | None = None + pressure: Any | None = None + tangential_pressure: Any | None = None + twist: Any | None = None + altitude_angle: Any | None = None + azimuth_angle: Any | None = None @dataclass -class KeySourceActions: - """Represents a sequence of key actions.""" +class ReleaseActionsParameters: + """ReleaseActionsParameters.""" - id: str = "" - actions: list[PauseAction | KeyDownAction | KeyUpAction] = field(default_factory=list) + context: Any | None = None - @property - def type(self) -> str: - return "key" - def to_dict(self) -> dict: - """Convert the KeySourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} +@dataclass +class SetFilesParameters: + """SetFilesParameters.""" + + context: Any | None = None + element: Any | None = None + files: list[Any] = field(default_factory=list) @dataclass -class PointerSourceActions: - """Represents a sequence of pointer actions.""" - - id: str = "" - parameters: PointerParameters | None = None - actions: list[PauseAction | PointerDownAction | PointerUpAction | PointerMoveAction] = field(default_factory=list) - - def __post_init__(self): - if self.parameters is None: - self.parameters = PointerParameters() - - @property - def type(self) -> str: - return "pointer" - - def to_dict(self) -> dict: - """Convert the PointerSourceActions to a dictionary.""" - result: dict[str, Any] = { - "type": self.type, - "id": self.id, - "actions": [action.to_dict() for action in self.actions], - } - if self.parameters: - result["parameters"] = self.parameters.to_dict() - return result +class FileDialogInfo: + """FileDialogInfo - parameters for the input.fileDialogOpened event.""" + context: Any | None = None + element: Any | None = None + multiple: bool | None = None + + @classmethod + def from_json(cls, params: dict) -> "FileDialogInfo": + """Deserialize event params into FileDialogInfo.""" + return cls( + context=params.get("context"), + element=params.get("element"), + multiple=params.get("multiple"), + ) @dataclass -class WheelSourceActions: - """Represents a sequence of wheel actions.""" +class PointerMoveAction: + """PointerMoveAction.""" - id: str = "" - actions: list[PauseAction | WheelScrollAction] = field(default_factory=list) + type: str = field(default="pointerMove", init=False) + x: Any | None = None + y: Any | None = None + duration: Any | None = None + origin: Any | None = None + properties: Any | None = None - @property - def type(self) -> str: - return "wheel" +@dataclass +class PointerDownAction: + """PointerDownAction.""" - def to_dict(self) -> dict: - """Convert the WheelSourceActions to a dictionary.""" - return {"type": self.type, "id": self.id, "actions": [action.to_dict() for action in self.actions]} + type: str = field(default="pointerDown", init=False) + button: Any | None = None + properties: Any | None = None +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "file_dialog_opened": "input.fileDialogOpened", +} @dataclass -class FileDialogInfo: - """Represents file dialog information from input.fileDialogOpened event.""" +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type - context: str - multiple: bool - element: dict | None = None - @classmethod - def from_dict(cls, data: dict) -> "FileDialogInfo": - """Creates a FileDialogInfo instance from a dictionary. +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - data: A dictionary containing the file dialog information. + params: Raw BiDi event params with camelCase keys. Returns: - FileDialogInfo: A new instance of FileDialogInfo. + An instance of the dataclass, or the raw dict on failure. """ - return cls(context=data["context"], multiple=data["multiple"], element=data.get("element")) + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) -# Event Class -class FileDialogOpened: - """Event class for input.fileDialogOpened event.""" +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() - event_class = "input.fileDialogOpened" - @classmethod - def from_json(cls, json): - """Create FileDialogInfo from JSON data.""" - return FileDialogInfo.from_dict(json) class Input: - """BiDi implementation of the input module.""" + """WebDriver BiDi input module.""" + + EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + def perform_actions(self, context: Any | None = None, actions: list[Any] | None = None): + """Execute input.performActions.""" + if context is None: + raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") + if actions is None: + raise TypeError("perform_actions() missing required argument: {{snake_param!r}}") + + params = { + "context": context, + "actions": actions, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.performActions", params) + result = self._conn.execute(cmd) + return result - def __init__(self, conn): - self.conn = conn - self.subscriptions = {} - self.callbacks = {} + def release_actions(self, context: Any | None = None): + """Execute input.releaseActions.""" + if context is None: + raise TypeError("release_actions() missing required argument: {{snake_param!r}}") - def perform_actions( - self, - context: str, - actions: list[NoneSourceActions | KeySourceActions | PointerSourceActions | WheelSourceActions], - ) -> None: - """Performs a sequence of user input actions. + params = { + "context": context, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.releaseActions", params) + result = self._conn.execute(cmd) + return result - Args: - context: The browsing context ID where actions should be performed. - actions: A list of source actions to perform. - """ - params = {"context": context, "actions": [action.to_dict() for action in actions]} - self.conn.execute(command_builder("input.performActions", params)) + def set_files(self, context: Any | None = None, element: Any | None = None, files: list[Any] | None = None): + """Execute input.setFiles.""" + if context is None: + raise TypeError("set_files() missing required argument: {{snake_param!r}}") + if element is None: + raise TypeError("set_files() missing required argument: {{snake_param!r}}") + if files is None: + raise TypeError("set_files() missing required argument: {{snake_param!r}}") + + params = { + "context": context, + "element": element, + "files": files, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("input.setFiles", params) + result = self._conn.execute(cmd) + return result - def release_actions(self, context: str) -> None: - """Releases all input state for the given context. + def add_file_dialog_handler(self, callback) -> int: + """Subscribe to the input.fileDialogOpened event. Args: - context: The browsing context ID to release actions for. + callback: Callable invoked with a FileDialogInfo when a file dialog opens. + + Returns: + A handler ID that can be passed to remove_file_dialog_handler. """ - params = {"context": context} - self.conn.execute(command_builder("input.releaseActions", params)) + return self._event_manager.add_event_handler("file_dialog_opened", callback) - def set_files(self, context: str, element: dict, files: list[str]) -> None: - """Sets files for a file input element. + def remove_file_dialog_handler(self, handler_id: int) -> None: + """Unsubscribe a previously registered file dialog event handler. Args: - context: The browsing context ID. - element: The element reference (script.SharedReference). - files: A list of file paths to set. + handler_id: The handler ID returned by add_file_dialog_handler. """ - params = {"context": context, "element": element, "files": files} - self.conn.execute(command_builder("input.setFiles", params)) + return self._event_manager.remove_event_handler("file_dialog_opened", handler_id) - def add_file_dialog_handler(self, handler) -> int: - """Add a handler for file dialog opened events. + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. Args: - handler: Callback function that takes a FileDialogInfo object. + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). Returns: - int: Callback ID for removing the handler later. + The callback ID. """ - # Subscribe to the event if not already subscribed - if FileDialogOpened.event_class not in self.subscriptions: - session = Session(self.conn) - self.conn.execute(session.subscribe(FileDialogOpened.event_class)) - self.subscriptions[FileDialogOpened.event_class] = [] + return self._event_manager.add_event_handler(event, callback, contexts) - # Add callback - the callback receives the parsed FileDialogInfo directly - callback_id = self.conn.add_callback(FileDialogOpened, handler) - - self.subscriptions[FileDialogOpened.event_class].append(callback_id) - self.callbacks[callback_id] = handler - - return callback_id - - def remove_file_dialog_handler(self, callback_id: int) -> None: - """Remove a file dialog handler. + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. Args: - callback_id: The callback ID returned by add_file_dialog_handler. + event: The event to unsubscribe from. + callback_id: The callback ID. """ - if callback_id in self.callbacks: - del self.callbacks[callback_id] - - if FileDialogOpened.event_class in self.subscriptions: - if callback_id in self.subscriptions[FileDialogOpened.event_class]: - self.subscriptions[FileDialogOpened.event_class].remove(callback_id) - - # If no more callbacks for this event, unsubscribe - if not self.subscriptions[FileDialogOpened.event_class]: - session = Session(self.conn) - self.conn.execute(session.unsubscribe(FileDialogOpened.event_class)) - del self.subscriptions[FileDialogOpened.event_class] - - self.conn.remove_callback(FileDialogOpened, callback_id) + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: input.fileDialogOpened +FileDialogOpened = globals().get('FileDialogInfo', dict) # Fallback to dict if type not defined + + +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Input.EVENT_CONFIGS = { + "file_dialog_opened": EventConfig( + "file_dialog_opened", + "input.fileDialogOpened", + _globals.get("FileDialogOpened", dict) if _globals.get("FileDialogOpened") else dict, + ), +} diff --git a/py/selenium/webdriver/common/bidi/log.py b/py/selenium/webdriver/common/bidi/log.py index 575545776bda8..3c6a95d74f6d1 100644 --- a/py/selenium/webdriver/common/bidi/log.py +++ b/py/selenium/webdriver/common/bidi/log.py @@ -1,81 +1,303 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: log from __future__ import annotations -from dataclasses import dataclass from typing import Any +from dataclasses import dataclass +import threading +from collections.abc import Callable +from selenium.webdriver.common.bidi.session import Session -class LogEntryAdded: - event_class = "log.entryAdded" +class Level: + """Level.""" + + DEBUG = "debug" + INFO = "info" + WARN = "warn" + ERROR = "error" - @classmethod - def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry | JavaScriptLogEntry | None: - if json["type"] == "console": - return ConsoleLogEntry.from_json(json) - elif json["type"] == "javascript": - return JavaScriptLogEntry.from_json(json) - return None + +LogLevel = Level + +@dataclass +class BaseLogEntry: + """BaseLogEntry.""" + + level: Any | None = None + source: Any | None = None + text: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None + + +@dataclass +class GenericLogEntry: + """GenericLogEntry.""" + + type: str | None = None @dataclass class ConsoleLogEntry: - level: str - text: str - timestamp: str - method: str - args: list[dict[str, Any]] - type_: str + """ConsoleLogEntry - a console log entry from the browser.""" + + type_: str | None = None + method: str | None = None + args: list | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stack_trace: Any | None = None @classmethod - def from_json(cls, json: dict[str, Any]) -> ConsoleLogEntry: + def from_json(cls, params: dict) -> "ConsoleLogEntry": + """Deserialize from BiDi params dict.""" return cls( - level=json["level"], - text=json["text"], - timestamp=json["timestamp"], - method=json["method"], - args=json["args"], - type_=json["type"], + type_=params.get("type"), + method=params.get("method"), + args=params.get("args"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stack_trace=params.get("stackTrace"), ) - @dataclass -class JavaScriptLogEntry: - level: str - text: str - timestamp: str - stacktrace: dict[str, Any] - type_: str +class JavascriptLogEntry: + """JavascriptLogEntry - a JavaScript error log entry from the browser.""" + + type_: str | None = None + level: Any | None = None + text: Any | None = None + source: Any | None = None + timestamp: Any | None = None + stacktrace: Any | None = None @classmethod - def from_json(cls, json: dict[str, Any]) -> JavaScriptLogEntry: + def from_json(cls, params: dict) -> "JavascriptLogEntry": + """Deserialize from BiDi params dict.""" return cls( - level=json["level"], - text=json["text"], - timestamp=json["timestamp"], - stacktrace=json["stackTrace"], - type_=json["type"], + type_=params.get("type"), + level=params.get("level"), + text=params.get("text"), + source=params.get("source"), + timestamp=params.get("timestamp"), + stacktrace=params.get("stackTrace"), ) +Entry = GenericLogEntry | ConsoleLogEntry | JavascriptLogEntry -class LogLevel: - """Represents log level.""" +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "entry_added": "log.entryAdded", +} - DEBUG = "debug" - INFO = "info" - WARN = "warn" - ERROR = "error" +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. + + Args: + params: Raw BiDi event params with camelCase keys. + + Returns: + An instance of the dataclass, or the raw dict on failure. + """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + + +class Log: + """WebDriver BiDi log module.""" + + EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. + + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). + + Returns: + The callback ID. + """ + return self._event_manager.add_event_handler(event, callback, contexts) + + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. + + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + """ + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: log.entryAdded +EntryAdded = Entry + + +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Log.EVENT_CONFIGS = { + "entry_added": EventConfig( + "entry_added", + "log.entryAdded", + _globals.get("EntryAdded", dict) if _globals.get("EntryAdded") else dict, + ), +} diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py index 82472838dccde..6a0edf0b2b5e7 100644 --- a/py/selenium/webdriver/common/bidi/network.py +++ b/py/selenium/webdriver/common/bidi/network.py @@ -1,338 +1,1063 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - +# WebDriver BiDi module: network from __future__ import annotations -from collections.abc import Callable from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field +import threading +from collections.abc import Callable +from selenium.webdriver.common.bidi.session import Session -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.remote.websocket_connection import WebSocketConnection +class SameSite: + """SameSite.""" -class NetworkEvent: - """Represents a network event.""" + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default" - def __init__(self, event_class: str, **kwargs: Any) -> None: - self.event_class = event_class - self.params = kwargs - @classmethod - def from_json(cls, json: dict[str, Any]) -> NetworkEvent: - return cls(event_class=json.get("event_class", ""), **json) +class DataType: + """DataType.""" + REQUEST = "request" + RESPONSE = "response" -class Network: - EVENTS = { - "before_request": "network.beforeRequestSent", - "response_started": "network.responseStarted", - "response_completed": "network.responseCompleted", - "auth_required": "network.authRequired", - "fetch_error": "network.fetchError", - "continue_request": "network.continueRequest", - "continue_auth": "network.continueWithAuth", - } - - PHASES = { - "before_request": "beforeRequestSent", - "response_started": "responseStarted", - "auth_required": "authRequired", - } - - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn - self.intercepts: list[str] = [] - self.callbacks: dict[str | int, Any] = {} - self.subscriptions: dict[str, list[int]] = {} - def _add_intercept( - self, - phases: list[str] | None = None, - contexts: list[str] | None = None, - url_patterns: list[Any] | None = None, - ) -> dict[str, Any]: - """Add an intercept to the network. +class InterceptPhase: + """InterceptPhase.""" + + BEFOREREQUESTSENT = "beforeRequestSent" + RESPONSESTARTED = "responseStarted" + AUTHREQUIRED = "authRequired" + + +class ContinueWithAuthNoCredentials: + """ContinueWithAuthNoCredentials.""" + + DEFAULT = "default" + CANCEL = "cancel" + + +@dataclass +class AuthChallenge: + """AuthChallenge.""" + + scheme: str | None = None + realm: str | None = None + + +@dataclass +class AuthCredentials: + """AuthCredentials.""" + + type: str = field(default="password", init=False) + username: str | None = None + password: str | None = None + + +@dataclass +class BaseParameters: + """BaseParameters.""" + + context: Any | None = None + is_blocked: bool | None = None + navigation: Any | None = None + redirect_count: Any | None = None + request: Any | None = None + timestamp: Any | None = None + intercepts: list[Any] = field(default_factory=list) + + +@dataclass +class StringValue: + """StringValue.""" + + type: str = field(default="string", init=False) + value: str | None = None + + +@dataclass +class Base64Value: + """Base64Value.""" + + type: str = field(default="base64", init=False) + value: str | None = None + + +@dataclass +class Cookie: + """Cookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + +@dataclass +class CookieHeader: + """CookieHeader.""" + + name: str | None = None + value: Any | None = None + + +@dataclass +class FetchTimingInfo: + """FetchTimingInfo.""" + + time_origin: Any | None = None + request_time: Any | None = None + redirect_start: Any | None = None + redirect_end: Any | None = None + fetch_start: Any | None = None + dns_start: Any | None = None + dns_end: Any | None = None + connect_start: Any | None = None + connect_end: Any | None = None + tls_start: Any | None = None + request_start: Any | None = None + response_start: Any | None = None + response_end: Any | None = None + + +@dataclass +class Header: + """Header.""" + + name: str | None = None + value: Any | None = None + + +@dataclass +class Initiator: + """Initiator.""" + + column_number: Any | None = None + line_number: Any | None = None + request: Any | None = None + stack_trace: Any | None = None + type: Any | None = None + + +@dataclass +class ResponseContent: + """ResponseContent.""" + + size: Any | None = None + + +@dataclass +class ResponseData: + """ResponseData.""" + + url: str | None = None + protocol: str | None = None + status: Any | None = None + status_text: str | None = None + from_cache: bool | None = None + headers: list[Any] = field(default_factory=list) + mime_type: str | None = None + bytes_received: Any | None = None + headers_size: Any | None = None + body_size: Any | None = None + content: Any | None = None + auth_challenges: list[Any] = field(default_factory=list) + + +@dataclass +class SetCookieHeader: + """SetCookieHeader.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + http_only: bool | None = None + expiry: str | None = None + max_age: Any | None = None + path: str | None = None + same_site: Any | None = None + secure: bool | None = None + + +@dataclass +class UrlPatternPattern: + """UrlPatternPattern.""" + + type: str = field(default="pattern", init=False) + protocol: str | None = None + hostname: str | None = None + port: str | None = None + pathname: str | None = None + search: str | None = None + + +@dataclass +class UrlPatternString: + """UrlPatternString.""" + + type: str = field(default="string", init=False) + pattern: str | None = None + + +@dataclass +class AddDataCollectorParameters: + """AddDataCollectorParameters.""" + + data_types: list[Any] = field(default_factory=list) + max_encoded_data_size: Any | None = None + collector_type: Any | None = None + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + +@dataclass +class AddDataCollectorResult: + """AddDataCollectorResult.""" + + collector: Any | None = None + + +@dataclass +class AddInterceptParameters: + """AddInterceptParameters.""" + + phases: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + url_patterns: list[Any] = field(default_factory=list) + + +@dataclass +class AddInterceptResult: + """AddInterceptResult.""" + + intercept: Any | None = None + + +@dataclass +class ContinueResponseParameters: + """ContinueResponseParameters.""" + + request: Any | None = None + cookies: list[Any] = field(default_factory=list) + credentials: Any | None = None + headers: list[Any] = field(default_factory=list) + reason_phrase: str | None = None + status_code: Any | None = None + + +@dataclass +class ContinueWithAuthParameters: + """ContinueWithAuthParameters.""" + + request: Any | None = None + + +@dataclass +class ContinueWithAuthCredentials: + """ContinueWithAuthCredentials.""" + + action: str = field(default="provideCredentials", init=False) + credentials: Any | None = None + + +@dataclass +class disownDataParameters: + """disownDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + request: Any | None = None + + +@dataclass +class FailRequestParameters: + """FailRequestParameters.""" + + request: Any | None = None + + +@dataclass +class GetDataParameters: + """GetDataParameters.""" + + data_type: Any | None = None + collector: Any | None = None + disown: bool | None = None + request: Any | None = None + + +@dataclass +class GetDataResult: + """GetDataResult.""" + + bytes: Any | None = None + + +@dataclass +class ProvideResponseParameters: + """ProvideResponseParameters.""" + + request: Any | None = None + body: Any | None = None + cookies: list[Any] = field(default_factory=list) + headers: list[Any] = field(default_factory=list) + reason_phrase: str | None = None + status_code: Any | None = None + + +@dataclass +class RemoveDataCollectorParameters: + """RemoveDataCollectorParameters.""" + + collector: Any | None = None + + +@dataclass +class RemoveInterceptParameters: + """RemoveInterceptParameters.""" + + intercept: Any | None = None + + +@dataclass +class SetCacheBehaviorParameters: + """SetCacheBehaviorParameters.""" + + cache_behavior: Any | None = None + contexts: list[Any] = field(default_factory=list) + + +@dataclass +class SetExtraHeadersParameters: + """SetExtraHeadersParameters.""" + + headers: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + +@dataclass +class ResponseStartedParameters: + """ResponseStartedParameters.""" + + response: Any | None = None + + +class BytesValue: + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ + + TYPE_STRING = "string" + TYPE_BASE64 = "base64" + + def __init__(self, type: Any | None, value: Any | None) -> None: + self.type = type + self.value = value + + def to_bidi_dict(self) -> dict: + return {"type": self.type, "value": self.value} + +class Request: + """Wraps a BiDi network request event params and provides request action methods.""" + + def __init__(self, conn, params): + self._conn = conn + self._params = params if isinstance(params, dict) else {} + req = self._params.get("request", {}) or {} + self.url = req.get("url", "") + self._request_id = req.get("request") + + def continue_request(self, **kwargs): + """Continue the intercepted request.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb + + params = {"request": self._request_id} + params.update(kwargs) + self._conn.execute(_cb("network.continueRequest", params)) + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "auth_required": "network.authRequired", + "before_request": "network.beforeRequestSent", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - phases: A list of phases to intercept. Default is None (empty list). - contexts: A list of contexts to intercept. Default is None. - url_patterns: A list of URL patterns to intercept. Default is None. + params: Raw BiDi event params with camelCase keys. Returns: - str: intercept id + An instance of the dataclass, or the raw dict on failure. """ + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() + + + + +class Network: + """WebDriver BiDi network module.""" + + EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn) -> None: + self._conn = conn + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) + self.intercepts: list[Any] = [] + self._handler_intercepts: dict[str, Any] = {} + + def add_data_collector( + self, + data_types: list[Any] | None = None, + max_encoded_data_size: Any | None = None, + collector_type: Any | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute network.addDataCollector.""" + if data_types is None: + raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") + if max_encoded_data_size is None: + raise TypeError("add_data_collector() missing required argument: {{snake_param!r}}") + + params = { + "dataTypes": data_types, + "maxEncodedDataSize": max_encoded_data_size, + "collectorType": collector_type, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.addDataCollector", params) + result = self._conn.execute(cmd) + return result + + def add_intercept( + self, + phases: list[Any] | None = None, + contexts: list[Any] | None = None, + url_patterns: list[Any] | None = None, + ): + """Execute network.addIntercept.""" if phases is None: - phases = [] - params = {} - if contexts is not None: - params["contexts"] = contexts - if url_patterns is not None: - params["urlPatterns"] = url_patterns - if len(phases) > 0: - params["phases"] = phases - else: - params["phases"] = ["beforeRequestSent"] + raise TypeError("add_intercept() missing required argument: {{snake_param!r}}") + + params = { + "phases": phases, + "contexts": contexts, + "urlPatterns": url_patterns, + } + params = {k: v for k, v in params.items() if v is not None} cmd = command_builder("network.addIntercept", params) + result = self._conn.execute(cmd) + return result - result: dict[str, Any] = self.conn.execute(cmd) - self.intercepts.append(result["intercept"]) + def continue_request( + self, + request: Any | None = None, + body: Any | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, + method: Any | None = None, + url: Any | None = None, + ): + """Execute network.continueRequest.""" + if request is None: + raise TypeError("continue_request() missing required argument: {{snake_param!r}}") + + params = { + "request": request, + "body": body, + "cookies": cookies, + "headers": headers, + "method": method, + "url": url, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueRequest", params) + result = self._conn.execute(cmd) return result - def _remove_intercept(self, intercept: str | None = None) -> None: - """Remove a specific intercept, or all intercepts. + def continue_response( + self, + request: Any | None = None, + cookies: list[Any] | None = None, + credentials: Any | None = None, + headers: list[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): + """Execute network.continueResponse.""" + if request is None: + raise TypeError("continue_response() missing required argument: {{snake_param!r}}") - Args: - intercept: The intercept to remove. Default is None. + params = { + "request": request, + "cookies": cookies, + "credentials": credentials, + "headers": headers, + "reasonPhrase": reason_phrase, + "statusCode": status_code, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueResponse", params) + result = self._conn.execute(cmd) + return result - Raises: - ValueError: If intercept is not found. + def continue_with_auth(self, request: Any | None = None): + """Execute network.continueWithAuth.""" + if request is None: + raise TypeError("continue_with_auth() missing required argument: {{snake_param!r}}") - Note: - If intercept is None, all intercepts will be removed. - """ + params = { + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.continueWithAuth", params) + result = self._conn.execute(cmd) + return result + + def disown_data(self, data_type: Any | None = None, collector: Any | None = None, request: Any | None = None): + """Execute network.disownData.""" + if data_type is None: + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + if collector is None: + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + if request is None: + raise TypeError("disown_data() missing required argument: {{snake_param!r}}") + + params = { + "dataType": data_type, + "collector": collector, + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.disownData", params) + result = self._conn.execute(cmd) + return result + + def fail_request(self, request: Any | None = None): + """Execute network.failRequest.""" + if request is None: + raise TypeError("fail_request() missing required argument: {{snake_param!r}}") + + params = { + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.failRequest", params) + result = self._conn.execute(cmd) + return result + + def get_data( + self, + data_type: Any | None = None, + collector: Any | None = None, + disown: bool | None = None, + request: Any | None = None, + ): + """Execute network.getData.""" + if data_type is None: + raise TypeError("get_data() missing required argument: {{snake_param!r}}") + if request is None: + raise TypeError("get_data() missing required argument: {{snake_param!r}}") + + params = { + "dataType": data_type, + "collector": collector, + "disown": disown, + "request": request, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.getData", params) + result = self._conn.execute(cmd) + return result + + def provide_response( + self, + request: Any | None = None, + body: Any | None = None, + cookies: list[Any] | None = None, + headers: list[Any] | None = None, + reason_phrase: Any | None = None, + status_code: Any | None = None, + ): + """Execute network.provideResponse.""" + if request is None: + raise TypeError("provide_response() missing required argument: {{snake_param!r}}") + + params = { + "request": request, + "body": body, + "cookies": cookies, + "headers": headers, + "reasonPhrase": reason_phrase, + "statusCode": status_code, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.provideResponse", params) + result = self._conn.execute(cmd) + return result + + def remove_data_collector(self, collector: Any | None = None): + """Execute network.removeDataCollector.""" + if collector is None: + raise TypeError("remove_data_collector() missing required argument: {{snake_param!r}}") + + params = { + "collector": collector, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.removeDataCollector", params) + result = self._conn.execute(cmd) + return result + + def remove_intercept(self, intercept: Any | None = None): + """Execute network.removeIntercept.""" if intercept is None: - intercepts_to_remove = self.intercepts.copy() # create a copy before iterating - for intercept_id in intercepts_to_remove: # remove all intercepts - self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept_id})) - self.intercepts.remove(intercept_id) - else: - try: - self.conn.execute(command_builder("network.removeIntercept", {"intercept": intercept})) - self.intercepts.remove(intercept) - except Exception as e: - raise Exception(f"Exception: {e}") - - def _on_request(self, event_name: str, callback: Callable[[Request], Any]) -> int: - """Set a callback function to subscribe to a network event. + raise TypeError("remove_intercept() missing required argument: {{snake_param!r}}") + + params = { + "intercept": intercept, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.removeIntercept", params) + result = self._conn.execute(cmd) + return result + + def set_cache_behavior(self, cache_behavior: Any | None = None, contexts: list[Any] | None = None): + """Execute network.setCacheBehavior.""" + if cache_behavior is None: + raise TypeError("set_cache_behavior() missing required argument: {{snake_param!r}}") + + params = { + "cacheBehavior": cache_behavior, + "contexts": contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.setCacheBehavior", params) + result = self._conn.execute(cmd) + return result + + def set_extra_headers( + self, + headers: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute network.setExtraHeaders.""" + if headers is None: + raise TypeError("set_extra_headers() missing required argument: {{snake_param!r}}") + + params = { + "headers": headers, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.setExtraHeaders", params) + result = self._conn.execute(cmd) + return result + + def before_request_sent(self, initiator: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.beforeRequestSent.""" + if method is None: + raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") + if params is None: + raise TypeError("before_request_sent() missing required argument: {{snake_param!r}}") + + params = { + "initiator": initiator, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.beforeRequestSent", params) + result = self._conn.execute(cmd) + return result + + def fetch_error(self, error_text: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.fetchError.""" + if error_text is None: + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + if method is None: + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + if params is None: + raise TypeError("fetch_error() missing required argument: {{snake_param!r}}") + + params = { + "errorText": error_text, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.fetchError", params) + result = self._conn.execute(cmd) + return result + + def response_completed(self, response: Any | None = None, method: Any | None = None, params: Any | None = None): + """Execute network.responseCompleted.""" + if response is None: + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + if method is None: + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + if params is None: + raise TypeError("response_completed() missing required argument: {{snake_param!r}}") + + params = { + "response": response, + "method": method, + "params": params, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseCompleted", params) + result = self._conn.execute(cmd) + return result + + def response_started(self, response: Any | None = None): + """Execute network.responseStarted.""" + if response is None: + raise TypeError("response_started() missing required argument: {{snake_param!r}}") + + params = { + "response": response, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("network.responseStarted", params) + result = self._conn.execute(cmd) + return result + + def _add_intercept(self, phases=None, url_patterns=None): + """Add a low-level network intercept. Args: - event_name: The event to subscribe to. - callback: The callback function to execute on event. - Takes Request object as argument. + phases: list of intercept phases (default: ["beforeRequestSent"]) + url_patterns: optional URL patterns to filter Returns: - int: callback id + dict with "intercept" key containing the intercept ID """ - event = NetworkEvent(event_name) - - def _callback(event_data: NetworkEvent) -> None: - request = Request( - network=self, - request_id=event_data.params["request"].get("request", None), - body_size=event_data.params["request"].get("bodySize", None), - cookies=event_data.params["request"].get("cookies", None), - resource_type=event_data.params["request"].get("goog:resourceType", None), - headers=event_data.params["request"].get("headers", None), - headers_size=event_data.params["request"].get("headersSize", None), - timings=event_data.params["request"].get("timings", None), - url=event_data.params["request"].get("url", None), - ) - callback(request) + from selenium.webdriver.common.bidi.common import command_builder as _cb - callback_id: int = self.conn.add_callback(event, _callback) - - if event_name in self.callbacks: - self.callbacks[event_name].append(callback_id) - else: - self.callbacks[event_name] = [callback_id] - - return callback_id + if phases is None: + phases = ["beforeRequestSent"] + params = {"phases": phases} + if url_patterns: + params["urlPatterns"] = url_patterns + result = self._conn.execute(_cb("network.addIntercept", params)) + if result: + intercept_id = result.get("intercept") + if intercept_id and intercept_id not in self.intercepts: + self.intercepts.append(intercept_id) + return result + def _remove_intercept(self, intercept_id): + """Remove a low-level network intercept.""" + from selenium.webdriver.common.bidi.common import command_builder as _cb - def add_request_handler( - self, - event: str, - callback: Callable[[Request], Any], - url_patterns: list[Any] | None = None, - contexts: list[str] | None = None, - ) -> int: - """Add a request handler to the network. + self._conn.execute(_cb("network.removeIntercept", {"intercept": intercept_id})) + if intercept_id in self.intercepts: + self.intercepts.remove(intercept_id) + def add_request_handler(self, event, callback, url_patterns=None): + """Add a handler for network requests at the specified phase. Args: - event: The event to subscribe to. - callback: The callback function to execute on request interception. - Takes Request object as argument. - url_patterns: A list of URL patterns to intercept. Default is None. - contexts: A list of contexts to intercept. Default is None. + event: Event name, e.g. ``"before_request"``. + callback: Callable receiving a :class:`Request` instance. + url_patterns: optional list of URL pattern dicts to filter. Returns: - int: callback id + callback_id int for later removal via remove_request_handler. """ - try: - event_name = self.EVENTS[event] - phase_name = self.PHASES[event] - except KeyError: - raise Exception(f"Event {event} not found") - - result = self._add_intercept(phases=[phase_name], url_patterns=url_patterns, contexts=contexts) - callback_id = self._on_request(event_name, callback) - - if event_name in self.subscriptions: - self.subscriptions[event_name].append(callback_id) - else: - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.subscribe", params)) - self.subscriptions[event_name] = [callback_id] - - self.callbacks[callback_id] = result["intercept"] - return callback_id + phase_map = { + "before_request": "beforeRequestSent", + "before_request_sent": "beforeRequestSent", + "response_started": "responseStarted", + "auth_required": "authRequired", + } + phase = phase_map.get(event, "beforeRequestSent") + intercept_result = self._add_intercept(phases=[phase], url_patterns=url_patterns) + intercept_id = intercept_result.get("intercept") if intercept_result else None - def remove_request_handler(self, event: str, callback_id: int) -> None: - """Remove a request handler from the network. + def _request_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request = Request(self._conn, raw) + callback(request) + + callback_id = self.add_event_handler(event, _request_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id + def remove_request_handler(self, event, callback_id): + """Remove a network request handler and its associated network intercept. Args: - event: The event to unsubscribe from. - callback_id: The callback id to remove. + event: The event name used when adding the handler. + callback_id: The int returned by add_request_handler. """ - try: - event_name = self.EVENTS[event] - except KeyError: - raise Exception(f"Event {event} not found") - - net_event = NetworkEvent(event_name) - - self.conn.remove_callback(net_event, callback_id) - self._remove_intercept(self.callbacks[callback_id]) - del self.callbacks[callback_id] - self.subscriptions[event_name].remove(callback_id) - if len(self.subscriptions[event_name]) == 0: - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.unsubscribe", params)) - del self.subscriptions[event_name] - - def clear_request_handlers(self) -> None: - """Clear all request handlers from the network.""" - for event_name in self.subscriptions: - net_event = NetworkEvent(event_name) - for callback_id in self.subscriptions[event_name]: - self.conn.remove_callback(net_event, callback_id) - self._remove_intercept(self.callbacks[callback_id]) - del self.callbacks[callback_id] - params: dict[str, Any] = {} - params["events"] = [event_name] - self.conn.execute(command_builder("session.unsubscribe", params)) - self.subscriptions = {} - - def add_auth_handler(self, username: str, password: str) -> int: - """Add an authentication handler to the network. + self.remove_event_handler(event, callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id) + def clear_request_handlers(self): + """Clear all request handlers and remove all tracked intercepts.""" + self.clear_event_handlers() + for intercept_id in list(self.intercepts): + self._remove_intercept(intercept_id) + def add_auth_handler(self, username, password): + """Add an auth handler that automatically provides credentials. Args: - username: The username to authenticate with. - password: The password to authenticate with. + username: The username for basic authentication. + password: The password for basic authentication. Returns: - int: callback id + callback_id int for later removal via remove_auth_handler. """ - event = "auth_required" + from selenium.webdriver.common.bidi.common import command_builder as _cb - def _callback(request: Request) -> None: - request._continue_with_auth(username, password) + # Set up network intercept for authRequired phase + intercept_result = self._add_intercept(phases=["authRequired"]) + intercept_id = intercept_result.get("intercept") if intercept_result else None - return self.add_request_handler(event, _callback) + def _auth_callback(params): + raw = ( + params + if isinstance(params, dict) + else (params.__dict__ if hasattr(params, "__dict__") else {}) + ) + request_id = ( + raw.get("request", {}).get("request") + if isinstance(raw, dict) + else None + ) + if request_id: + self._conn.execute( + _cb( + "network.continueWithAuth", + { + "request": request_id, + "action": "provideCredentials", + "credentials": { + "type": "password", + "username": username, + "password": password, + }, + }, + ) + ) - def remove_auth_handler(self, callback_id: int) -> None: - """Remove an authentication handler from the network. + callback_id = self.add_event_handler("auth_required", _auth_callback) + if intercept_id: + self._handler_intercepts[callback_id] = intercept_id + return callback_id + def remove_auth_handler(self, callback_id): + """Remove an auth handler by callback ID and its associated network intercept. Args: - callback_id: The callback id to remove. + callback_id: The handler ID returned by add_auth_handler. """ - event = "auth_required" - self.remove_request_handler(event, callback_id) + self.remove_event_handler("auth_required", callback_id) + intercept_id = self._handler_intercepts.pop(callback_id, None) + if intercept_id: + self._remove_intercept(intercept_id) + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. -class Request: - """Represents an intercepted network request.""" + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). - def __init__( - self, - network: Network, - request_id: Any, - body_size: int | None = None, - cookies: Any = None, - resource_type: str | None = None, - headers: Any = None, - headers_size: int | None = None, - method: str | None = None, - timings: Any = None, - url: str | None = None, - ) -> None: - self.network = network - self.request_id = request_id - self.body_size = body_size - self.cookies = cookies - self.resource_type = resource_type - self.headers = headers - self.headers_size = headers_size - self.method = method - self.timings = timings - self.url = url - - def fail_request(self) -> None: - """Fail this request.""" - if not self.request_id: - raise ValueError("Request not found.") - - params: dict[str, Any] = {"request": self.request_id} - self.network.conn.execute(command_builder("network.failRequest", params)) + Returns: + The callback ID. + """ + return self._event_manager.add_event_handler(event, callback, contexts) - def continue_request( - self, - body: Any = None, - method: str | None = None, - headers: Any = None, - cookies: Any = None, - url: str | None = None, - ) -> None: - """Continue after intercepting this request.""" - if not self.request_id: - raise ValueError("Request not found.") - - params: dict[str, Any] = {"request": self.request_id} - if body is not None: - params["body"] = body - if method is not None: - params["method"] = method - if headers is not None: - params["headers"] = headers - if cookies is not None: - params["cookies"] = cookies - if url is not None: - params["url"] = url - - self.network.conn.execute(command_builder("network.continueRequest", params)) - - def _continue_with_auth(self, username: str | None = None, password: str | None = None) -> None: - """Continue with authentication. + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. Args: - username: The username to authenticate with. - password: The password to authenticate with. - - Note: - If username or password is None, it attempts auth with no credentials. + event: The event to unsubscribe from. + callback_id: The callback ID. """ - params: dict[str, Any] = {} - params["request"] = self.request_id + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: network.authRequired +AuthRequired = globals().get('AuthRequiredParameters', dict) # Fallback to dict if type not defined - if not username or not password: # no credentials is valid option - params["action"] = "default" - else: - params["action"] = "provideCredentials" - params["credentials"] = {"type": "password", "username": username, "password": password} - self.network.conn.execute(command_builder("network.continueWithAuth", params)) +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Network.EVENT_CONFIGS = { + "auth_required": EventConfig( + "auth_required", + "network.authRequired", + _globals.get("AuthRequired", dict) if _globals.get("AuthRequired") else dict, + ), + "before_request": EventConfig("before_request", "network.beforeRequestSent", _globals.get("dict", dict)), +} diff --git a/py/selenium/webdriver/common/bidi/permissions.py b/py/selenium/webdriver/common/bidi/permissions.py index 17faa1ff5454f..6dd138da17309 100644 --- a/py/selenium/webdriver/common/bidi/permissions.py +++ b/py/selenium/webdriver/common/bidi/permissions.py @@ -15,12 +15,20 @@ # specific language governing permissions and limitations # under the License. +"""WebDriver BiDi Permissions module.""" -from selenium.webdriver.common.bidi.common import command_builder +from __future__ import annotations +from enum import Enum +from typing import Any -class PermissionState: - """Represents the possible permission states.""" +from .common import command_builder + +_VALID_PERMISSION_STATES = {"granted", "denied", "prompt"} + + +class PermissionState(str, Enum): + """Permission state enumeration.""" GRANTED = "granted" DENIED = "denied" @@ -28,56 +36,69 @@ class PermissionState: class PermissionDescriptor: - """Represents a permission descriptor.""" + """Descriptor for a permission.""" - def __init__(self, name: str): + def __init__(self, name: str) -> None: + """Initialize a PermissionDescriptor. + + Args: + name: The name of the permission (e.g., 'geolocation', 'microphone', 'camera') + """ self.name = name - def to_dict(self) -> dict: - return {"name": self.name} + def __repr__(self) -> str: + return f"PermissionDescriptor('{self.name}')" class Permissions: - """BiDi implementation of the permissions module.""" + """WebDriver BiDi Permissions module.""" - def __init__(self, conn): - self.conn = conn + def __init__(self, websocket_connection: Any) -> None: + """Initialize the Permissions module. + + Args: + websocket_connection: The WebSocket connection for sending BiDi commands + """ + self._conn = websocket_connection def set_permission( self, - descriptor: str | PermissionDescriptor, - state: str, - origin: str, + descriptor: PermissionDescriptor | str, + state: PermissionState | str, + origin: str | None = None, user_context: str | None = None, ) -> None: - """Sets a permission state for a given permission descriptor. + """Set a permission for a given origin. Args: - descriptor: The permission name (str) or PermissionDescriptor object. - Examples: "geolocation", "camera", "microphone". - state: The permission state (granted, denied, prompt). - origin: The origin for which the permission is set. - user_context: The user context id (optional). + descriptor: The permission descriptor or permission name as a string + state: The desired permission state + origin: The origin for which to set the permission + user_context: Optional user context ID to scope the permission Raises: - ValueError: If the permission state is invalid. + ValueError: If the state is not a valid permission state """ - if state not in [PermissionState.GRANTED, PermissionState.DENIED, PermissionState.PROMPT]: - valid_states = f"{PermissionState.GRANTED}, {PermissionState.DENIED}, {PermissionState.PROMPT}" - raise ValueError(f"Invalid permission state. Must be one of: {valid_states}") + state_value = state.value if isinstance(state, PermissionState) else state + if state_value not in _VALID_PERMISSION_STATES: + raise ValueError( + f"Invalid permission state: {state_value!r}. " + f"Must be one of {sorted(_VALID_PERMISSION_STATES)}" + ) if isinstance(descriptor, str): - permission_descriptor = PermissionDescriptor(descriptor) + descriptor_dict = {"name": descriptor} else: - permission_descriptor = descriptor + descriptor_dict = {"name": descriptor.name} - params = { - "descriptor": permission_descriptor.to_dict(), - "state": state, - "origin": origin, + params: dict[str, Any] = { + "descriptor": descriptor_dict, + "state": state_value, } - + if origin is not None: + params["origin"] = origin if user_context is not None: params["userContext"] = user_context - self.conn.execute(command_builder("permissions.setPermission", params)) + cmd = command_builder("permissions.setPermission", params) + self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/py.typed b/py/selenium/webdriver/common/bidi/py.typed new file mode 100755 index 0000000000000..e69de29bb2d1d diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index e37b3269a4ade..5a7d2792a1221 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -1,40 +1,31 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import datetime -import math -from dataclasses import dataclass -from typing import Any +# WebDriver BiDi module: script +from __future__ import annotations -from selenium.common.exceptions import WebDriverException -from selenium.webdriver.common.bidi.common import command_builder -from selenium.webdriver.common.bidi.log import LogEntryAdded +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field +import threading +from collections.abc import Callable from selenium.webdriver.common.bidi.session import Session -class ResultOwnership: - """Represents the possible result ownership types.""" +class SpecialNumber: + """SpecialNumber.""" - NONE = "none" - ROOT = "root" + NAN = "NaN" + _0 = "-0" + INFINITY = "Infinity" + _INFINITY = "-Infinity" class RealmType: - """Represents the possible realm types.""" + """RealmType.""" WINDOW = "window" DEDICATED_WORKER = "dedicated-worker" @@ -46,502 +37,1314 @@ class RealmType: WORKLET = "worklet" +class ResultOwnership: + """ResultOwnership.""" + + ROOT = "root" + NONE = "none" + + +@dataclass +class ChannelValue: + """ChannelValue.""" + + type: str = field(default="channel", init=False) + value: Any | None = None + + +@dataclass +class ChannelProperties: + """ChannelProperties.""" + + channel: Any | None = None + serialization_options: Any | None = None + ownership: Any | None = None + + +@dataclass +class EvaluateResultSuccess: + """EvaluateResultSuccess.""" + + type: str = field(default="success", init=False) + result: Any | None = None + realm: Any | None = None + + +@dataclass +class EvaluateResultException: + """EvaluateResultException.""" + + type: str = field(default="exception", init=False) + exception_details: Any | None = None + realm: Any | None = None + + +@dataclass +class ExceptionDetails: + """ExceptionDetails.""" + + column_number: Any | None = None + exception: Any | None = None + line_number: Any | None = None + stack_trace: Any | None = None + text: str | None = None + + +@dataclass +class ArrayLocalValue: + """ArrayLocalValue.""" + + type: str = field(default="array", init=False) + value: Any | None = None + + +@dataclass +class DateLocalValue: + """DateLocalValue.""" + + type: str = field(default="date", init=False) + value: str | None = None + + +@dataclass +class MapLocalValue: + """MapLocalValue.""" + + type: str = field(default="map", init=False) + value: Any | None = None + + @dataclass -class RealmInfo: - """Represents information about a realm.""" +class ObjectLocalValue: + """ObjectLocalValue.""" - realm: str - origin: str - type: str - context: str | None = None + type: str = field(default="object", init=False) + value: Any | None = None + + +@dataclass +class RegExpValue: + """RegExpValue.""" + + pattern: str | None = None + flags: str | None = None + + +@dataclass +class RegExpLocalValue: + """RegExpLocalValue.""" + + type: str = field(default="regexp", init=False) + value: Any | None = None + + +@dataclass +class SetLocalValue: + """SetLocalValue.""" + + type: str = field(default="set", init=False) + value: Any | None = None + + +@dataclass +class UndefinedValue: + """UndefinedValue.""" + + type: str = field(default="undefined", init=False) + + +@dataclass +class NullValue: + """NullValue.""" + + type: str = field(default="null", init=False) + + +@dataclass +class StringValue: + """StringValue.""" + + type: str = field(default="string", init=False) + value: str | None = None + + +@dataclass +class NumberValue: + """NumberValue.""" + + type: str = field(default="number", init=False) + value: Any | None = None + + +@dataclass +class BooleanValue: + """BooleanValue.""" + + type: str = field(default="boolean", init=False) + value: bool | None = None + + +@dataclass +class BigIntValue: + """BigIntValue.""" + + type: str = field(default="bigint", init=False) + value: str | None = None + + +@dataclass +class BaseRealmInfo: + """BaseRealmInfo.""" + + realm: Any | None = None + origin: str | None = None + + +@dataclass +class WindowRealmInfo: + """WindowRealmInfo.""" + + type: str = field(default="window", init=False) + context: Any | None = None sandbox: str | None = None - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmInfo": - """Creates a RealmInfo instance from a dictionary. - Args: - json: A dictionary containing the realm information. +@dataclass +class DedicatedWorkerRealmInfo: + """DedicatedWorkerRealmInfo.""" - Returns: - RealmInfo: A new instance of RealmInfo. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in RealmInfo") - if "origin" not in json: - raise ValueError("Missing required field 'origin' in RealmInfo") - if "type" not in json: - raise ValueError("Missing required field 'type' in RealmInfo") - - return cls( - realm=json["realm"], - origin=json["origin"], - type=json["type"], - context=json.get("context"), - sandbox=json.get("sandbox"), - ) + type: str = field(default="dedicated-worker", init=False) + owners: list[Any] = field(default_factory=list) @dataclass -class Source: - """Represents the source of a script message.""" +class SharedWorkerRealmInfo: + """SharedWorkerRealmInfo.""" - realm: str - context: str | None = None + type: str = field(default="shared-worker", init=False) - @classmethod - def from_json(cls, json: dict[str, Any]) -> "Source": - """Creates a Source instance from a dictionary. - Args: - json: A dictionary containing the source information. +@dataclass +class ServiceWorkerRealmInfo: + """ServiceWorkerRealmInfo.""" - Returns: - Source: A new instance of Source. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in Source") + type: str = field(default="service-worker", init=False) - return cls( - realm=json["realm"], - context=json.get("context"), - ) + +@dataclass +class WorkerRealmInfo: + """WorkerRealmInfo.""" + + type: str = field(default="worker", init=False) @dataclass -class EvaluateResult: - """Represents the result of script evaluation.""" +class PaintWorkletRealmInfo: + """PaintWorkletRealmInfo.""" - type: str - realm: str - result: dict | None = None - exception_details: dict | None = None + type: str = field(default="paint-worklet", init=False) - @classmethod - def from_json(cls, json: dict[str, Any]) -> "EvaluateResult": - """Creates an EvaluateResult instance from a dictionary. - Args: - json: A dictionary containing the evaluation result. +@dataclass +class AudioWorkletRealmInfo: + """AudioWorkletRealmInfo.""" - Returns: - EvaluateResult: A new instance of EvaluateResult. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in EvaluateResult") - if "type" not in json: - raise ValueError("Missing required field 'type' in EvaluateResult") - - return cls( - type=json["type"], - realm=json["realm"], - result=json.get("result"), - exception_details=json.get("exceptionDetails"), - ) + type: str = field(default="audio-worklet", init=False) -class ScriptMessage: - """Represents a script message event.""" +@dataclass +class WorkletRealmInfo: + """WorkletRealmInfo.""" - event_class = "script.message" + type: str = field(default="worklet", init=False) - def __init__(self, channel: str, data: dict, source: Source): - self.channel = channel - self.data = data - self.source = source - @classmethod - def from_json(cls, json: dict[str, Any]) -> "ScriptMessage": - """Creates a ScriptMessage instance from a dictionary. +@dataclass +class SharedReference: + """SharedReference.""" - Args: - json: A dictionary containing the script message. + shared_id: Any | None = None + handle: Any | None = None - Returns: - ScriptMessage: A new instance of ScriptMessage. - """ - if "channel" not in json: - raise ValueError("Missing required field 'channel' in ScriptMessage") - if "data" not in json: - raise ValueError("Missing required field 'data' in ScriptMessage") - if "source" not in json: - raise ValueError("Missing required field 'source' in ScriptMessage") - - return cls( - channel=json["channel"], - data=json["data"], - source=Source.from_json(json["source"]), - ) +@dataclass +class RemoteObjectReference: + """RemoteObjectReference.""" -class RealmCreated: - """Represents a realm created event.""" + handle: Any | None = None + shared_id: Any | None = None - event_class = "script.realmCreated" - def __init__(self, realm_info: RealmInfo): - self.realm_info = realm_info +@dataclass +class SymbolRemoteValue: + """SymbolRemoteValue.""" - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmCreated": - """Creates a RealmCreated instance from a dictionary. + type: str = field(default="symbol", init=False) + handle: Any | None = None + internal_id: Any | None = None - Args: - json: A dictionary containing the realm created event. - Returns: - RealmCreated: A new instance of RealmCreated. - """ - return cls(realm_info=RealmInfo.from_json(json)) +@dataclass +class ArrayRemoteValue: + """ArrayRemoteValue.""" + type: str = field(default="array", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None -class RealmDestroyed: - """Represents a realm destroyed event.""" - event_class = "script.realmDestroyed" +@dataclass +class ObjectRemoteValue: + """ObjectRemoteValue.""" - def __init__(self, realm: str): - self.realm = realm + type: str = field(default="object", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None - @classmethod - def from_json(cls, json: dict[str, Any]) -> "RealmDestroyed": - """Creates a RealmDestroyed instance from a dictionary. - Args: - json: A dictionary containing the realm destroyed event. +@dataclass +class FunctionRemoteValue: + """FunctionRemoteValue.""" - Returns: - RealmDestroyed: A new instance of RealmDestroyed. - """ - if "realm" not in json: - raise ValueError("Missing required field 'realm' in RealmDestroyed") + type: str = field(default="function", init=False) + handle: Any | None = None + internal_id: Any | None = None - return cls(realm=json["realm"]) +@dataclass +class RegExpRemoteValue: + """RegExpRemoteValue.""" -class Script: - """BiDi implementation of the script module.""" + handle: Any | None = None + internal_id: Any | None = None - EVENTS = { - "message": "script.message", - "realm_created": "script.realmCreated", - "realm_destroyed": "script.realmDestroyed", - } - def __init__(self, conn, driver=None): - self.conn = conn - self.driver = driver - self.log_entry_subscribed = False - self.subscriptions = {} - self.callbacks = {} +@dataclass +class DateRemoteValue: + """DateRemoteValue.""" + + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class MapRemoteValue: + """MapRemoteValue.""" + + type: str = field(default="map", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + - # High-level APIs for SCRIPT module +@dataclass +class SetRemoteValue: + """SetRemoteValue.""" + + type: str = field(default="set", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class WeakMapRemoteValue: + """WeakMapRemoteValue.""" + + type: str = field(default="weakmap", init=False) + handle: Any | None = None + internal_id: Any | None = None - def add_console_message_handler(self, handler): - self._subscribe_to_log_entries() - return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler)) - def add_javascript_error_handler(self, handler): - self._subscribe_to_log_entries() - return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("javascript", handler)) +@dataclass +class WeakSetRemoteValue: + """WeakSetRemoteValue.""" + + type: str = field(default="weakset", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class GeneratorRemoteValue: + """GeneratorRemoteValue.""" + + type: str = field(default="generator", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ErrorRemoteValue: + """ErrorRemoteValue.""" + + type: str = field(default="error", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ProxyRemoteValue: + """ProxyRemoteValue.""" + + type: str = field(default="proxy", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class PromiseRemoteValue: + """PromiseRemoteValue.""" - def remove_console_message_handler(self, id): - self.conn.remove_callback(LogEntryAdded, id) - self._unsubscribe_from_log_entries() + type: str = field(default="promise", init=False) + handle: Any | None = None + internal_id: Any | None = None - remove_javascript_error_handler = remove_console_message_handler - def pin(self, script: str) -> str: - """Pins a script to the current browsing context. +@dataclass +class TypedArrayRemoteValue: + """TypedArrayRemoteValue.""" + + type: str = field(default="typedarray", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class ArrayBufferRemoteValue: + """ArrayBufferRemoteValue.""" + + type: str = field(default="arraybuffer", init=False) + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class NodeListRemoteValue: + """NodeListRemoteValue.""" + + type: str = field(default="nodelist", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class HTMLCollectionRemoteValue: + """HTMLCollectionRemoteValue.""" + + type: str = field(default="htmlcollection", init=False) + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class NodeRemoteValue: + """NodeRemoteValue.""" + + type: str = field(default="node", init=False) + shared_id: Any | None = None + handle: Any | None = None + internal_id: Any | None = None + value: Any | None = None + + +@dataclass +class NodeProperties: + """NodeProperties.""" + + node_type: Any | None = None + child_node_count: Any | None = None + children: list[Any] = field(default_factory=list) + local_name: str | None = None + mode: Any | None = None + namespace_uri: str | None = None + node_value: str | None = None + shadow_root: Any | None = None + + +@dataclass +class WindowProxyRemoteValue: + """WindowProxyRemoteValue.""" + + type: str = field(default="window", init=False) + value: Any | None = None + handle: Any | None = None + internal_id: Any | None = None + + +@dataclass +class WindowProxyProperties: + """WindowProxyProperties.""" + + context: Any | None = None + + +@dataclass +class StackFrame: + """StackFrame.""" + + column_number: Any | None = None + function_name: str | None = None + line_number: Any | None = None + url: str | None = None + + +@dataclass +class StackTrace: + """StackTrace.""" + + call_frames: list[Any] = field(default_factory=list) + + +@dataclass +class Source: + """Source.""" + + realm: Any | None = None + context: Any | None = None + + +@dataclass +class RealmTarget: + """RealmTarget.""" + + realm: Any | None = None + + +@dataclass +class ContextTarget: + """ContextTarget.""" + + context: Any | None = None + sandbox: str | None = None + + +@dataclass +class AddPreloadScriptParameters: + """AddPreloadScriptParameters.""" + + function_declaration: str | None = None + arguments: list[Any] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + sandbox: str | None = None + + +@dataclass +class AddPreloadScriptResult: + """AddPreloadScriptResult.""" + + script: Any | None = None + + +@dataclass +class DisownParameters: + """DisownParameters.""" + + handles: list[Any] = field(default_factory=list) + target: Any | None = None + + +@dataclass +class CallFunctionParameters: + """CallFunctionParameters.""" + + function_declaration: str | None = None + await_promise: bool | None = None + target: Any | None = None + arguments: list[Any] = field(default_factory=list) + result_ownership: Any | None = None + serialization_options: Any | None = None + this: Any | None = None + user_activation: bool | None = None + + +@dataclass +class EvaluateParameters: + """EvaluateParameters.""" + + expression: str | None = None + target: Any | None = None + await_promise: bool | None = None + result_ownership: Any | None = None + serialization_options: Any | None = None + user_activation: bool | None = None + + +@dataclass +class GetRealmsParameters: + """GetRealmsParameters.""" + + context: Any | None = None + type: Any | None = None + + +@dataclass +class GetRealmsResult: + """GetRealmsResult.""" + + realms: list[Any] = field(default_factory=list) + + +@dataclass +class RemovePreloadScriptParameters: + """RemovePreloadScriptParameters.""" + + script: Any | None = None + + +@dataclass +class MessageParameters: + """MessageParameters.""" + + channel: Any | None = None + data: Any | None = None + source: Any | None = None + + +@dataclass +class RealmDestroyedParameters: + """RealmDestroyedParameters.""" + + realm: Any | None = None + + +# BiDi Event Name to Parameter Type Mapping +EVENT_NAME_MAPPING = { + "realm_created": "script.realmCreated", + "realm_destroyed": "script.realmDestroyed", +} + +@dataclass +class EventConfig: + """Configuration for a BiDi event.""" + event_key: str + bidi_event: str + event_class: type + + +class _EventWrapper: + """Wrapper to provide event_class attribute for WebSocketConnection callbacks.""" + def __init__(self, bidi_event: str, event_class: type): + self.event_class = bidi_event # WebSocket expects the BiDi event name as event_class + self._python_class = event_class # Keep reference to Python dataclass for deserialization + + def from_json(self, params: dict) -> Any: + """Deserialize event params into the wrapped Python dataclass. Args: - script: The script to pin. + params: Raw BiDi event params with camelCase keys. Returns: - str: The ID of the pinned script. + An instance of the dataclass, or the raw dict on failure. """ - return self._add_preload_script(script) + if self._python_class is None or self._python_class is dict: + return params + try: + # Delegate to a classmethod from_json if the class defines one + if hasattr(self._python_class, "from_json") and callable( + self._python_class.from_json + ): + return self._python_class.from_json(params) + import dataclasses as dc + + snake_params = {self._camel_to_snake(k): v for k, v in params.items()} + if dc.is_dataclass(self._python_class): + valid_fields = {f.name for f in dc.fields(self._python_class)} + filtered = {k: v for k, v in snake_params.items() if k in valid_fields} + return self._python_class(**filtered) + return self._python_class(**snake_params) + except Exception: + return params + + @staticmethod + def _camel_to_snake(name: str) -> str: + result = [name[0].lower()] + for char in name[1:]: + if char.isupper(): + result.extend(["_", char.lower()]) + else: + result.append(char) + return "".join(result) + + +class _EventManager: + """Manages event subscriptions and callbacks.""" + + def __init__(self, conn, event_configs: dict[str, EventConfig]): + self.conn = conn + self.event_configs = event_configs + self.subscriptions: dict = {} + self._event_wrappers = {} # Cache of _EventWrapper objects + self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()} + self._available_events = ", ".join(sorted(event_configs.keys())) + self._subscription_lock = threading.Lock() + + # Create event wrappers for each event + for config in event_configs.values(): + wrapper = _EventWrapper(config.bidi_event, config.event_class) + self._event_wrappers[config.bidi_event] = wrapper + + def validate_event(self, event: str) -> EventConfig: + event_config = self.event_configs.get(event) + if not event_config: + raise ValueError(f"Event '{event}' not found. Available events: {self._available_events}") + return event_config + + def subscribe_to_event(self, bidi_event: str, contexts: list[str] | None = None) -> None: + """Subscribe to a BiDi event if not already subscribed.""" + with self._subscription_lock: + if bidi_event not in self.subscriptions: + session = Session(self.conn) + result = session.subscribe([bidi_event], contexts=contexts) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self.subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + + def unsubscribe_from_event(self, bidi_event: str) -> None: + """Unsubscribe from a BiDi event if no more callbacks exist.""" + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry is not None and not entry["callbacks"]: + session = Session(self.conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self.subscriptions[bidi_event] + + def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + self.subscriptions[bidi_event]["callbacks"].append(callback_id) + + def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None: + with self._subscription_lock: + entry = self.subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + event_config = self.validate_event(event) + # Use the event wrapper for add_callback + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + callback_id = self.conn.add_callback(event_wrapper, callback) + self.subscribe_to_event(event_config.bidi_event, contexts) + self.add_callback_to_tracking(event_config.bidi_event, callback_id) + return callback_id + + def remove_event_handler(self, event: str, callback_id: int) -> None: + event_config = self.validate_event(event) + event_wrapper = self._event_wrappers.get(event_config.bidi_event) + self.conn.remove_callback(event_wrapper, callback_id) + self.remove_callback_from_tracking(event_config.bidi_event, callback_id) + self.unsubscribe_from_event(event_config.bidi_event) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + with self._subscription_lock: + if not self.subscriptions: + return + session = Session(self.conn) + for bidi_event, entry in list(self.subscriptions.items()): + event_wrapper = self._event_wrappers.get(bidi_event) + callbacks = entry["callbacks"] if isinstance(entry, dict) else entry + if event_wrapper: + for callback_id in callbacks: + self.conn.remove_callback(event_wrapper, callback_id) + sub_id = ( + entry.get("subscription_id") if isinstance(entry, dict) else None + ) + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + self.subscriptions.clear() - def unpin(self, script_id: str) -> None: - """Unpins a script from the current browsing context. - Args: - script_id: The ID of the pinned script to unpin. - """ - self._remove_preload_script(script_id) - def execute(self, script: str, *args) -> dict: - """Executes a script in the current browsing context. - Args: - script: The script function to execute. - *args: Arguments to pass to the script function. +class Script: + """WebDriver BiDi script module.""" - Returns: - dict: The result value from the script execution. + EVENT_CONFIGS: dict[str, EventConfig] = {} + def __init__(self, conn, driver=None) -> None: + self._conn = conn + self._driver = driver + self._event_manager = _EventManager(conn, self.EVENT_CONFIGS) - Raises: - WebDriverException: If the script execution fails. - """ - if self.driver is None: - raise WebDriverException("Driver reference is required for script execution") - browsing_context_id = self.driver.current_window_handle + def add_preload_script( + self, + function_declaration: Any | None = None, + arguments: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + sandbox: Any | None = None, + ): + """Execute script.addPreloadScript.""" + if function_declaration is None: + raise TypeError("add_preload_script() missing required argument: {{snake_param!r}}") - # Convert arguments to the format expected by BiDi call_function (LocalValue Type) - arguments = [] - for arg in args: - arguments.append(self.__convert_to_local_value(arg)) + params = { + "functionDeclaration": function_declaration, + "arguments": arguments, + "contexts": contexts, + "userContexts": user_contexts, + "sandbox": sandbox, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.addPreloadScript", params) + result = self._conn.execute(cmd) + return result + + def disown(self, handles: list[Any] | None = None, target: Any | None = None): + """Execute script.disown.""" + if handles is None: + raise TypeError("disown() missing required argument: {{snake_param!r}}") + if target is None: + raise TypeError("disown() missing required argument: {{snake_param!r}}") - target = {"context": browsing_context_id} + params = { + "handles": handles, + "target": target, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.disown", params) + result = self._conn.execute(cmd) + return result - result = self._call_function( - function_declaration=script, await_promise=True, target=target, arguments=arguments if arguments else None - ) + def call_function( + self, + function_declaration: Any | None = None, + await_promise: bool | None = None, + target: Any | None = None, + arguments: list[Any] | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + this: Any | None = None, + user_activation: bool | None = None, + ): + """Execute script.callFunction.""" + if function_declaration is None: + raise TypeError("call_function() missing required argument: {{snake_param!r}}") + if await_promise is None: + raise TypeError("call_function() missing required argument: {{snake_param!r}}") + if target is None: + raise TypeError("call_function() missing required argument: {{snake_param!r}}") + + params = { + "functionDeclaration": function_declaration, + "awaitPromise": await_promise, + "target": target, + "arguments": arguments, + "resultOwnership": result_ownership, + "serializationOptions": serialization_options, + "this": this, + "userActivation": user_activation, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.callFunction", params) + result = self._conn.execute(cmd) + return result + + def evaluate( + self, + expression: Any | None = None, + target: Any | None = None, + await_promise: bool | None = None, + result_ownership: Any | None = None, + serialization_options: Any | None = None, + user_activation: bool | None = None, + ): + """Execute script.evaluate.""" + if expression is None: + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + if target is None: + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + if await_promise is None: + raise TypeError("evaluate() missing required argument: {{snake_param!r}}") + + params = { + "expression": expression, + "target": target, + "awaitPromise": await_promise, + "resultOwnership": result_ownership, + "serializationOptions": serialization_options, + "userActivation": user_activation, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.evaluate", params) + result = self._conn.execute(cmd) + return result - if result.type == "success": - return result.result if result.result is not None else {} - else: - error_message = "Error while executing script" - if result.exception_details: - if "text" in result.exception_details: - error_message += f": {result.exception_details['text']}" - elif "message" in result.exception_details: - error_message += f": {result.exception_details['message']}" - - raise WebDriverException(error_message) - - def __convert_to_local_value(self, value) -> dict: - """Converts a Python value to BiDi LocalValue format.""" - if value is None: - return {"type": "null"} - elif isinstance(value, bool): - return {"type": "boolean", "value": value} - elif isinstance(value, (int, float)): + def get_realms(self, context: Any | None = None, type: Any | None = None): + """Execute script.getRealms.""" + params = { + "context": context, + "type": type, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.getRealms", params) + result = self._conn.execute(cmd) + return result + + def remove_preload_script(self, script: Any | None = None): + """Execute script.removePreloadScript.""" + if script is None: + raise TypeError("remove_preload_script() missing required argument: {{snake_param!r}}") + + params = { + "script": script, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.removePreloadScript", params) + result = self._conn.execute(cmd) + return result + + def message(self, channel: Any | None = None, data: Any | None = None, source: Any | None = None): + """Execute script.message.""" + if channel is None: + raise TypeError("message() missing required argument: {{snake_param!r}}") + if data is None: + raise TypeError("message() missing required argument: {{snake_param!r}}") + if source is None: + raise TypeError("message() missing required argument: {{snake_param!r}}") + + params = { + "channel": channel, + "data": data, + "source": source, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("script.message", params) + result = self._conn.execute(cmd) + return result + + def execute(self, function_declaration: str, *args, context_id: str | None = None) -> Any: + """Execute a function declaration in the browser context. + + Args: + function_declaration: The function as a string, e.g. ``"() => document.title"``. + *args: Optional Python values to pass as arguments to the function. + Each value is serialised to a BiDi ``LocalValue`` automatically. + Supported types: ``None``, ``bool``, ``int``, ``float`` + (including ``NaN`` and ``Infinity``), ``str``, ``list``, + ``dict``, and ``datetime.datetime``. + context_id: The browsing context ID to run in. Defaults to the + driver's current window handle when a driver was provided. + + Returns: + The inner RemoteValue result dict, or raises WebDriverException on exception. + """ + import math as _math + import datetime as _datetime + from selenium.common.exceptions import WebDriverException as _WebDriverException + + def _serialize_arg(value): + """Serialise a Python value to a BiDi LocalValue dict.""" + if value is None: + return {"type": "null"} + if isinstance(value, bool): + return {"type": "boolean", "value": value} + if isinstance(value, _datetime.datetime): + return {"type": "date", "value": value.isoformat()} if isinstance(value, float): - if math.isnan(value): + if _math.isnan(value): return {"type": "number", "value": "NaN"} - elif math.isinf(value): - if value > 0: - return {"type": "number", "value": "Infinity"} - else: - return {"type": "number", "value": "-Infinity"} - elif value == 0.0 and math.copysign(1.0, value) < 0: - return {"type": "number", "value": "-0"} - - JS_MAX_SAFE_INTEGER = 9007199254740991 - if isinstance(value, int) and (value > JS_MAX_SAFE_INTEGER or value < -JS_MAX_SAFE_INTEGER): - return {"type": "bigint", "value": str(value)} - - return {"type": "number", "value": value} - - elif isinstance(value, str): - return {"type": "string", "value": value} - elif isinstance(value, datetime.datetime): - # Convert Python datetime to JavaScript Date (ISO 8601 format) - return {"type": "date", "value": value.isoformat() + "Z" if value.tzinfo is None else value.isoformat()} - elif isinstance(value, datetime.date): - # Convert Python date to JavaScript Date - dt = datetime.datetime.combine(value, datetime.time.min).replace(tzinfo=datetime.timezone.utc) - return {"type": "date", "value": dt.isoformat()} - elif isinstance(value, set): - return {"type": "set", "value": [self.__convert_to_local_value(item) for item in value]} - elif isinstance(value, (list, tuple)): - return {"type": "array", "value": [self.__convert_to_local_value(item) for item in value]} - elif isinstance(value, dict): - return { - "type": "object", - "value": [ - [self.__convert_to_local_value(k), self.__convert_to_local_value(v)] for k, v in value.items() - ], - } - else: - # For other types, convert to string - return {"type": "string", "value": str(value)} - - # low-level APIs for script module + if _math.isinf(value): + return {"type": "number", "value": "Infinity" if value > 0 else "-Infinity"} + return {"type": "number", "value": value} + if isinstance(value, int): + _MAX_SAFE_INT = 9007199254740991 + if abs(value) > _MAX_SAFE_INT: + return {"type": "bigint", "value": str(value)} + return {"type": "number", "value": value} + if isinstance(value, str): + return {"type": "string", "value": value} + if isinstance(value, list): + return {"type": "array", "value": [_serialize_arg(v) for v in value]} + if isinstance(value, dict): + return {"type": "object", "value": [[str(k), _serialize_arg(v)] for k, v in value.items()]} + return value + + if context_id is None and self._driver is not None: + try: + context_id = self._driver.current_window_handle + except Exception: + pass + target = {"context": context_id} if context_id else {} + serialized_args = [_serialize_arg(a) for a in args] if args else None + raw = self.call_function( + function_declaration=function_declaration, + await_promise=True, + target=target, + arguments=serialized_args, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails", {}) + msg = exc.get("text", str(exc)) if isinstance(exc, dict) else str(exc) + raise _WebDriverException(msg) + if raw.get("type") == "success": + return raw.get("result") + return raw def _add_preload_script( self, - function_declaration: str, - arguments: list[dict[str, Any]] | None = None, - contexts: list[str] | None = None, - user_contexts: list[str] | None = None, - sandbox: str | None = None, - ) -> str: - """Adds a preload script. + function_declaration, + arguments=None, + contexts=None, + user_contexts=None, + sandbox=None, + ): + """Add a preload script with validation. Args: - function_declaration: The function declaration to preload. - arguments: The arguments to pass to the function. - contexts: The browsing context IDs to apply the script to. - user_contexts: The user context IDs to apply the script to. - sandbox: The sandbox name to apply the script to. + function_declaration: The JS function to run on page load. + arguments: Optional list of BiDi arguments. + contexts: Optional list of browsing context IDs. + user_contexts: Optional list of user context IDs. + sandbox: Optional sandbox name. Returns: - str: The preload script ID. + script_id: The ID of the added preload script (str). Raises: - ValueError: If both contexts and user_contexts are provided. + ValueError: If both contexts and user_contexts are specified. """ if contexts is not None and user_contexts is not None: raise ValueError("Cannot specify both contexts and user_contexts") + result = self.add_preload_script( + function_declaration=function_declaration, + arguments=arguments, + contexts=contexts, + user_contexts=user_contexts, + sandbox=sandbox, + ) + if isinstance(result, dict): + return result.get("script") + return result + def _remove_preload_script(self, script_id): + """Remove a preload script by ID. - params: dict[str, Any] = {"functionDeclaration": function_declaration} - - if arguments is not None: - params["arguments"] = arguments - if contexts is not None: - params["contexts"] = contexts - if user_contexts is not None: - params["userContexts"] = user_contexts - if sandbox is not None: - params["sandbox"] = sandbox + Args: + script_id: The ID of the preload script to remove. + """ + return self.remove_preload_script(script=script_id) + def pin(self, function_declaration): + """Pin (add) a preload script that runs on every page load. - result = self.conn.execute(command_builder("script.addPreloadScript", params)) - return result["script"] + Args: + function_declaration: The JS function to execute on page load. - def _remove_preload_script(self, script_id: str) -> None: - """Removes a preload script. + Returns: + script_id: The ID of the pinned script (str). + """ + return self._add_preload_script(function_declaration) + def unpin(self, script_id): + """Unpin (remove) a previously pinned preload script. Args: - script_id: The preload script ID to remove. + script_id: The ID returned by pin(). """ - params = {"script": script_id} - self.conn.execute(command_builder("script.removePreloadScript", params)) - - def _disown(self, handles: list[str], target: dict) -> None: - """Disowns the given handles. + return self._remove_preload_script(script_id=script_id) + def _evaluate( + self, + expression, + target, + await_promise, + result_ownership=None, + serialization_options=None, + user_activation=None, + ): + """Evaluate a script expression and return a structured result. Args: - handles: The handles to disown. - target: The target realm or context. - """ - params = { - "handles": handles, - "target": target, - } - self.conn.execute(command_builder("script.disown", params)) + expression: The JavaScript expression to evaluate. + target: A dict like {"context": } or {"realm": }. + await_promise: Whether to await a returned promise. + result_ownership: Optional result ownership setting. + serialization_options: Optional serialization options dict. + user_activation: Optional user activation flag. + Returns: + An object with .realm, .result (dict or None), and .exception_details (or None). + """ + class _EvalResult: + def __init__(self2, realm, result, exception_details): + self2.realm = realm + self2.result = result + self2.exception_details = exception_details + + raw = self.evaluate( + expression=expression, + target=target, + await_promise=await_promise, + result_ownership=result_ownership, + serialization_options=serialization_options, + user_activation=user_activation, + ) + if isinstance(raw, dict): + realm = raw.get("realm") + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _EvalResult(realm=realm, result=None, exception_details=exc) + return _EvalResult(realm=realm, result=raw.get("result"), exception_details=None) + return _EvalResult(realm=None, result=raw, exception_details=None) def _call_function( self, - function_declaration: str, - await_promise: bool, - target: dict, - arguments: list[dict] | None = None, - result_ownership: str | None = None, - serialization_options: dict | None = None, - this: dict | None = None, - user_activation: bool = False, - ) -> EvaluateResult: - """Calls a provided function with given arguments in a given realm. + function_declaration, + await_promise, + target, + arguments=None, + result_ownership=None, + this=None, + user_activation=None, + serialization_options=None, + ): + """Call a function and return a structured result. Args: - function_declaration: The function declaration to call. - await_promise: Whether to await promise resolution. - target: The target realm or context. - arguments: The arguments to pass to the function. - result_ownership: The result ownership type. - serialization_options: The serialization options. - this: The 'this' value for the function call. - user_activation: Whether to trigger user activation. + function_declaration: The JS function string. + await_promise: Whether to await the return value. + target: A dict like {"context": }. + arguments: Optional list of BiDi arguments. + result_ownership: Optional result ownership. + this: Optional 'this' binding. + user_activation: Optional user activation flag. + serialization_options: Optional serialization options dict. Returns: - EvaluateResult: The result of the function call. + An object with .result (dict or None) and .exception_details (or None). """ - params = { - "functionDeclaration": function_declaration, - "awaitPromise": await_promise, - "target": target, - "userActivation": user_activation, - } - - if arguments is not None: - params["arguments"] = arguments - if result_ownership is not None: - params["resultOwnership"] = result_ownership - if serialization_options is not None: - params["serializationOptions"] = serialization_options - if this is not None: - params["this"] = this - - result = self.conn.execute(command_builder("script.callFunction", params)) - return EvaluateResult.from_json(result) - - def _evaluate( - self, - expression: str, - target: dict, - await_promise: bool, - result_ownership: str | None = None, - serialization_options: dict | None = None, - user_activation: bool = False, - ) -> EvaluateResult: - """Evaluates a provided script in a given realm. + class _CallResult: + def __init__(self2, result, exception_details): + self2.result = result + self2.exception_details = exception_details + + raw = self.call_function( + function_declaration=function_declaration, + await_promise=await_promise, + target=target, + arguments=arguments, + result_ownership=result_ownership, + this=this, + user_activation=user_activation, + serialization_options=serialization_options, + ) + if isinstance(raw, dict): + if raw.get("type") == "exception": + exc = raw.get("exceptionDetails") + return _CallResult(result=None, exception_details=exc) + if raw.get("type") == "success": + return _CallResult(result=raw.get("result"), exception_details=None) + return _CallResult(result=raw, exception_details=None) + def _get_realms(self, context=None, type=None): + """Get all realms, optionally filtered by context and type. Args: - expression: The script expression to evaluate. - target: The target realm or context. - await_promise: Whether to await promise resolution. - result_ownership: The result ownership type. - serialization_options: The serialization options. - user_activation: Whether to trigger user activation. + context: Optional browsing context ID to filter by. + type: Optional realm type string to filter by (e.g. RealmType.WINDOW). Returns: - EvaluateResult: The result of the script evaluation. + List of realm info objects with .realm, .origin, .type, .context attributes. """ - params = { - "expression": expression, - "target": target, - "awaitPromise": await_promise, - "userActivation": user_activation, - } + class _RealmInfo: + def __init__(self2, realm, origin, type_, context): + self2.realm = realm + self2.origin = origin + self2.type = type_ + self2.context = context + + raw = self.get_realms(context=context, type=type) + realms_list = raw.get("realms", []) if isinstance(raw, dict) else [] + result = [] + for r in realms_list: + if isinstance(r, dict): + result.append(_RealmInfo( + realm=r.get("realm"), + origin=r.get("origin"), + type_=r.get("type"), + context=r.get("context"), + )) + return result + def _disown(self, handles, target): + """Disown handles in a browsing context. - if result_ownership is not None: - params["resultOwnership"] = result_ownership - if serialization_options is not None: - params["serializationOptions"] = serialization_options + Args: + handles: List of handle strings to disown. + target: A dict like {"context": }. + """ + return self.disown(handles=handles, target=target) + def _subscribe_log_entry(self, callback, entry_type_filter=None): + """Subscribe to log.entryAdded BiDi events with optional type filtering.""" + import threading as _threading + from selenium.webdriver.common.bidi.session import Session as _Session + from selenium.webdriver.common.bidi import log as _log_mod + + bidi_event = "log.entryAdded" + + if not hasattr(self, "_log_subscriptions"): + self._log_subscriptions = {} + self._log_lock = _threading.Lock() + + def _deserialize(params): + t = params.get("type") if isinstance(params, dict) else None + if t == "console": + cls = getattr(_log_mod, "ConsoleLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + elif t == "javascript": + cls = getattr(_log_mod, "JavascriptLogEntry", None) + if cls is not None and hasattr(cls, "from_json"): + try: + return cls.from_json(params) + except Exception: + pass + return params + + def _wrapped(raw): + entry = _deserialize(raw) + if entry_type_filter is None: + callback(entry) + else: + t = getattr(entry, "type_", None) or ( + entry.get("type") if isinstance(entry, dict) else None + ) + if t == entry_type_filter: + callback(entry) + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + callback_id = self._conn.add_callback(_wrapper, _wrapped) + with self._log_lock: + if bidi_event not in self._log_subscriptions: + session = _Session(self._conn) + result = session.subscribe([bidi_event]) + sub_id = ( + result.get("subscription") if isinstance(result, dict) else None + ) + self._log_subscriptions[bidi_event] = { + "callbacks": [], + "subscription_id": sub_id, + } + self._log_subscriptions[bidi_event]["callbacks"].append(callback_id) + return callback_id + def _unsubscribe_log_entry(self, callback_id): + """Unsubscribe a log entry callback by ID.""" + from selenium.webdriver.common.bidi.session import Session as _Session + + bidi_event = "log.entryAdded" + if not hasattr(self, "_log_subscriptions"): + return + + class _BidiRef: + event_class = bidi_event + + def from_json(self2, p): + return p + + _wrapper = _BidiRef() + self._conn.remove_callback(_wrapper, callback_id) + with self._log_lock: + entry = self._log_subscriptions.get(bidi_event) + if entry and callback_id in entry["callbacks"]: + entry["callbacks"].remove(callback_id) + if entry is not None and not entry["callbacks"]: + session = _Session(self._conn) + sub_id = entry.get("subscription_id") + if sub_id: + session.unsubscribe(subscriptions=[sub_id]) + else: + session.unsubscribe(events=[bidi_event]) + del self._log_subscriptions[bidi_event] + def add_console_message_handler(self, callback: Callable) -> int: + """Add a handler for console log messages (log.entryAdded type=console). - result = self.conn.execute(command_builder("script.evaluate", params)) - return EvaluateResult.from_json(result) + Args: + callback: Function called with a ConsoleLogEntry on each console message. - def _get_realms( - self, - context: str | None = None, - type: str | None = None, - ) -> list[RealmInfo]: - """Returns a list of all realms, optionally filtered. + Returns: + callback_id for use with remove_console_message_handler. + """ + return self._subscribe_log_entry(callback, entry_type_filter="console") + def remove_console_message_handler(self, callback_id: int) -> None: + """Remove a console message handler by callback ID.""" + self._unsubscribe_log_entry(callback_id) + def add_javascript_error_handler(self, callback: Callable) -> int: + """Add a handler for JavaScript error log messages (log.entryAdded type=javascript). Args: - context: The browsing context ID to filter by. - type: The realm type to filter by. + callback: Function called with a JavascriptLogEntry on each JS error. Returns: - List[RealmInfo]: A list of realm information. + callback_id for use with remove_javascript_error_handler. """ - params = {} + return self._subscribe_log_entry(callback, entry_type_filter="javascript") + def remove_javascript_error_handler(self, callback_id: int) -> None: + """Remove a JavaScript error handler by callback ID.""" + self._unsubscribe_log_entry(callback_id) - if context is not None: - params["context"] = context - if type is not None: - params["type"] = type + def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: + """Add an event handler. - result = self.conn.execute(command_builder("script.getRealms", params)) - return [RealmInfo.from_json(realm) for realm in result["realms"]] - - def _subscribe_to_log_entries(self): - if not self.log_entry_subscribed: - session = Session(self.conn) - self.conn.execute(session.subscribe(LogEntryAdded.event_class)) - self.log_entry_subscribed = True + Args: + event: The event to subscribe to. + callback: The callback function to execute on event. + contexts: The context IDs to subscribe to (optional). - def _unsubscribe_from_log_entries(self): - if self.log_entry_subscribed and LogEntryAdded.event_class not in self.conn.callbacks: - session = Session(self.conn) - self.conn.execute(session.unsubscribe(LogEntryAdded.event_class)) - self.log_entry_subscribed = False + Returns: + The callback ID. + """ + return self._event_manager.add_event_handler(event, callback, contexts) - def _handle_log_entry(self, type, handler): - def _handle_log_entry(log_entry): - if log_entry.type_ == type: - handler(log_entry) + def remove_event_handler(self, event: str, callback_id: int) -> None: + """Remove an event handler. - return _handle_log_entry + Args: + event: The event to unsubscribe from. + callback_id: The callback ID. + """ + return self._event_manager.remove_event_handler(event, callback_id) + + def clear_event_handlers(self) -> None: + """Clear all event handlers.""" + return self._event_manager.clear_event_handlers() + +# Event Info Type Aliases +# Event: script.realmCreated +RealmCreated = globals().get('RealmInfo', dict) # Fallback to dict if type not defined + +# Event: script.realmDestroyed +RealmDestroyed = globals().get('RealmDestroyedParameters', dict) # Fallback to dict if type not defined + + +# Populate EVENT_CONFIGS with event configuration mappings +_globals = globals() +Script.EVENT_CONFIGS = { + "realm_created": EventConfig( + "realm_created", + "script.realmCreated", + _globals.get("RealmCreated", dict) if _globals.get("RealmCreated") else dict, + ), + "realm_destroyed": EventConfig( + "realm_destroyed", + "script.realmDestroyed", + _globals.get("RealmDestroyed", dict) if _globals.get("RealmDestroyed") else dict, + ), +} diff --git a/py/selenium/webdriver/common/bidi/session.py b/py/selenium/webdriver/common/bidi/session.py index 3481c2d77842d..177421eca5ee8 100644 --- a/py/selenium/webdriver/common/bidi/session.py +++ b/py/selenium/webdriver/common/bidi/session.py @@ -1,134 +1,246 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: session +from __future__ import annotations - -from selenium.webdriver.common.bidi.common import command_builder +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field class UserPromptHandlerType: - """Represents the behavior of the user prompt handler.""" + """UserPromptHandlerType.""" ACCEPT = "accept" DISMISS = "dismiss" IGNORE = "ignore" - VALID_TYPES = {ACCEPT, DISMISS, IGNORE} + +@dataclass +class CapabilitiesRequest: + """CapabilitiesRequest.""" + + always_match: Any | None = None + first_match: list[Any] = field(default_factory=list) + + +@dataclass +class CapabilityRequest: + """CapabilityRequest.""" + + accept_insecure_certs: bool | None = None + browser_name: str | None = None + browser_version: str | None = None + platform_name: str | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None + + +@dataclass +class AutodetectProxyConfiguration: + """AutodetectProxyConfiguration.""" + + proxy_type: str = field(default="autodetect", init=False) + + +@dataclass +class DirectProxyConfiguration: + """DirectProxyConfiguration.""" + + proxy_type: str = field(default="direct", init=False) + + +@dataclass +class ManualProxyConfiguration: + """ManualProxyConfiguration.""" + + proxy_type: str = field(default="manual", init=False) + http_proxy: str | None = None + ssl_proxy: str | None = None + no_proxy: list[Any] = field(default_factory=list) + + +@dataclass +class SocksProxyConfiguration: + """SocksProxyConfiguration.""" + + socks_proxy: str | None = None + socks_version: Any | None = None + + +@dataclass +class PacProxyConfiguration: + """PacProxyConfiguration.""" + + proxy_type: str = field(default="pac", init=False) + proxy_autoconfig_url: str | None = None + + +@dataclass +class SystemProxyConfiguration: + """SystemProxyConfiguration.""" + + proxy_type: str = field(default="system", init=False) + + +@dataclass +class SubscribeParameters: + """SubscribeParameters.""" + + events: list[str] = field(default_factory=list) + contexts: list[Any] = field(default_factory=list) + user_contexts: list[Any] = field(default_factory=list) + + +@dataclass +class UnsubscribeByIDRequest: + """UnsubscribeByIDRequest.""" + + subscriptions: list[Any] = field(default_factory=list) + + +@dataclass +class UnsubscribeByAttributesRequest: + """UnsubscribeByAttributesRequest.""" + + events: list[str] = field(default_factory=list) + + +@dataclass +class StatusResult: + """StatusResult.""" + + ready: bool | None = None + message: str | None = None +@dataclass +class NewParameters: + """NewParameters.""" + + capabilities: Any | None = None + + +@dataclass +class NewResult: + """NewResult.""" + + session_id: str | None = None + accept_insecure_certs: bool | None = None + browser_name: str | None = None + browser_version: str | None = None + platform_name: str | None = None + set_window_rect: bool | None = None + user_agent: str | None = None + proxy: Any | None = None + unhandled_prompt_behavior: Any | None = None + web_socket_url: str | None = None + + +@dataclass +class SubscribeResult: + """SubscribeResult.""" + + subscription: Any | None = None + + +@dataclass class UserPromptHandler: - """Represents the configuration of the user prompt handler.""" + """UserPromptHandler.""" - def __init__( - self, - alert: str | None = None, - before_unload: str | None = None, - confirm: str | None = None, - default: str | None = None, - file: str | None = None, - prompt: str | None = None, - ): - """Initialize UserPromptHandler. - - Args: - alert: Handler type for alert prompts. - before_unload: Handler type for beforeUnload prompts. - confirm: Handler type for confirm prompts. - default: Default handler type for all prompts. - file: Handler type for file picker prompts. - prompt: Handler type for prompt dialogs. - - Raises: - ValueError: If any handler type is not valid. - """ - for field_name, value in [ - ("alert", alert), - ("before_unload", before_unload), - ("confirm", confirm), - ("default", default), - ("file", file), - ("prompt", prompt), - ]: - if value is not None and value not in UserPromptHandlerType.VALID_TYPES: - raise ValueError( - f"Invalid {field_name} handler type: {value}. Must be one of {UserPromptHandlerType.VALID_TYPES}" - ) - - self.alert = alert - self.before_unload = before_unload - self.confirm = confirm - self.default = default - self.file = file - self.prompt = prompt - - def to_dict(self) -> dict[str, str]: - """Convert the UserPromptHandler to a dictionary for BiDi protocol. - - Returns: - Dictionary representation suitable for BiDi protocol. - """ - field_mapping = { - "alert": "alert", - "before_unload": "beforeUnload", - "confirm": "confirm", - "default": "default", - "file": "file", - "prompt": "prompt", - } + alert: Any | None = None + before_unload: Any | None = None + confirm: Any | None = None + default: Any | None = None + file: Any | None = None + prompt: Any | None = None + def to_bidi_dict(self) -> dict: + """Convert to BiDi protocol dict with camelCase keys.""" result = {} - for attr_name, dict_key in field_mapping.items(): - value = getattr(self, attr_name) - if value is not None: - result[dict_key] = value + if self.alert is not None: + result["alert"] = self.alert + if self.before_unload is not None: + result["beforeUnload"] = self.before_unload + if self.confirm is not None: + result["confirm"] = self.confirm + if self.default is not None: + result["default"] = self.default + if self.file is not None: + result["file"] = self.file + if self.prompt is not None: + result["prompt"] = self.prompt return result - class Session: - def __init__(self, conn): - self.conn = conn + """WebDriver BiDi session module.""" + + def __init__(self, conn) -> None: + self._conn = conn - def subscribe(self, *events, browsing_contexts=None): + def status(self): + """Execute session.status.""" params = { - "events": events, } - if browsing_contexts is None: - browsing_contexts = [] - if browsing_contexts: - params["browsingContexts"] = browsing_contexts - return command_builder("session.subscribe", params) + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.status", params) + result = self._conn.execute(cmd) + return result + + def new(self, capabilities: Any | None = None): + """Execute session.new.""" + if capabilities is None: + raise TypeError("new() missing required argument: {{snake_param!r}}") + + params = { + "capabilities": capabilities, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.new", params) + result = self._conn.execute(cmd) + return result - def unsubscribe(self, *events, browsing_contexts=None): + def end(self): + """Execute session.end.""" params = { - "events": events, } - if browsing_contexts is None: - browsing_contexts = [] - if browsing_contexts: - params["browsingContexts"] = browsing_contexts - return command_builder("session.unsubscribe", params) + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.end", params) + result = self._conn.execute(cmd) + return result - def status(self): - """The session.status command returns information about the remote end's readiness. + def subscribe( + self, + events: list[Any] | None = None, + contexts: list[Any] | None = None, + user_contexts: list[Any] | None = None, + ): + """Execute session.subscribe.""" + if events is None: + raise TypeError("subscribe() missing required argument: {{snake_param!r}}") - Returns information about the remote end's readiness to create new sessions - and may include implementation-specific metadata. + params = { + "events": events, + "contexts": contexts, + "userContexts": user_contexts, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.subscribe", params) + result = self._conn.execute(cmd) + return result + + def unsubscribe(self, events: list[Any] | None = None, subscriptions: list[Any] | None = None): + """Execute session.unsubscribe.""" + params = { + "events": events, + "subscriptions": subscriptions, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("session.unsubscribe", params) + result = self._conn.execute(cmd) + return result - Returns: - Dictionary containing the ready state (bool), message (str) and metadata. - """ - cmd = command_builder("session.status", {}) - return self.conn.execute(cmd) diff --git a/py/selenium/webdriver/common/bidi/storage.py b/py/selenium/webdriver/common/bidi/storage.py index 992ede07f4100..fef35106c33b0 100644 --- a/py/selenium/webdriver/common/bidi/storage.py +++ b/py/selenium/webdriver/common/bidi/storage.py @@ -1,150 +1,151 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: storage from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field -from selenium.webdriver.common.bidi.common import command_builder -if TYPE_CHECKING: - from selenium.webdriver.remote.websocket_connection import WebSocketConnection +@dataclass +class PartitionKey: + """PartitionKey.""" + user_context: str | None = None + source_origin: str | None = None -class SameSite: - """Represents the possible same site values for cookies.""" - STRICT = "strict" - LAX = "lax" - NONE = "none" - DEFAULT = "default" +@dataclass +class GetCookiesParameters: + """GetCookiesParameters.""" + + filter: Any | None = None + partition: Any | None = None + + +@dataclass +class GetCookiesResult: + """GetCookiesResult.""" + + cookies: list[Any] = field(default_factory=list) + partition_key: Any | None = None + + +@dataclass +class SetCookieParameters: + """SetCookieParameters.""" + + cookie: Any | None = None + partition: Any | None = None + + +@dataclass +class SetCookieResult: + """SetCookieResult.""" + + partition_key: Any | None = None + + +@dataclass +class DeleteCookiesParameters: + """DeleteCookiesParameters.""" + + filter: Any | None = None + partition: Any | None = None + + +@dataclass +class DeleteCookiesResult: + """DeleteCookiesResult.""" + + partition_key: Any | None = None class BytesValue: - """Represents a bytes value.""" + """A string or base64-encoded bytes value used in cookie operations. + + This corresponds to network.BytesValue in the WebDriver BiDi specification, + wrapping either a plain string or a base64-encoded binary value. + """ - TYPE_BASE64 = "base64" TYPE_STRING = "string" + TYPE_BASE64 = "base64" - def __init__(self, type: str, value: str): + def __init__(self, type: Any | None, value: Any | None) -> None: self.type = type self.value = value - def to_dict(self) -> dict[str, str]: - """Converts the BytesValue to a dictionary. - - Returns: - A dictionary representation of the BytesValue. - """ + def to_bidi_dict(self) -> dict: return {"type": self.type, "value": self.value} +class SameSite: + """SameSite cookie attribute values.""" -class Cookie: - """Represents a cookie.""" - - def __init__( - self, - name: str, - value: BytesValue, - domain: str, - path: str | None = None, - size: int | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.size = size - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry + STRICT = "strict" + LAX = "lax" + NONE = "none" + DEFAULT = "default" + +@dataclass +class StorageCookie: + """A cookie object returned by storage.getCookies.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None @classmethod - def from_dict(cls, data: dict[str, Any]) -> Cookie: - """Creates a Cookie instance from a dictionary. - - Args: - data: A dictionary containing the cookie information. - - Returns: - A new instance of Cookie. - """ - # Validation for empty strings - name = data.get("name") - if not name: - raise ValueError("name is required and cannot be empty") - domain = data.get("domain") - if not domain: - raise ValueError("domain is required and cannot be empty") - - value = BytesValue(data.get("value", {}).get("type"), data.get("value", {}).get("value")) + def from_bidi_dict(cls, raw: dict) -> "StorageCookie": + """Deserialize a wire-level cookie dict to a StorageCookie.""" + value_raw = raw.get("value") + if isinstance(value_raw, dict): + value: Any = BytesValue(value_raw.get("type"), value_raw.get("value")) + else: + value = value_raw return cls( - name=str(name), + name=raw.get("name"), value=value, - domain=str(domain), - path=data.get("path"), - size=data.get("size"), - http_only=data.get("httpOnly"), - secure=data.get("secure"), - same_site=data.get("sameSite"), - expiry=data.get("expiry"), + domain=raw.get("domain"), + path=raw.get("path"), + size=raw.get("size"), + http_only=raw.get("httpOnly"), + secure=raw.get("secure"), + same_site=raw.get("sameSite"), + expiry=raw.get("expiry"), ) - +@dataclass class CookieFilter: - """Represents a filter for cookies.""" - - def __init__( - self, - name: str | None = None, - value: BytesValue | None = None, - domain: str | None = None, - path: str | None = None, - size: int | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.size = size - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry - - def to_dict(self) -> dict[str, Any]: - """Converts the CookieFilter to a dictionary. - - Returns: - A dictionary representation of the CookieFilter. - """ - result: dict[str, Any] = {} + """CookieFilter.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + size: Any | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} if self.name is not None: result["name"] = self.name if self.value is not None: - result["value"] = self.value.to_dict() + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value if self.domain is not None: result["domain"] = self.domain if self.path is not None: @@ -161,103 +162,28 @@ def to_dict(self) -> dict[str, Any]: result["expiry"] = self.expiry return result - -class PartitionKey: - """Represents a storage partition key.""" - - def __init__(self, user_context: str | None = None, source_origin: str | None = None): - self.user_context = user_context - self.source_origin = source_origin - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> PartitionKey: - """Creates a PartitionKey instance from a dictionary. - - Args: - data: A dictionary containing the partition key information. - - Returns: - A new instance of PartitionKey. - """ - return cls( - user_context=data.get("userContext"), - source_origin=data.get("sourceOrigin"), - ) - - -class BrowsingContextPartitionDescriptor: - """Represents a browsing context partition descriptor.""" - - def __init__(self, context: str): - self.type = "context" - self.context = context - - def to_dict(self) -> dict[str, str]: - """Converts the BrowsingContextPartitionDescriptor to a dictionary. - - Returns: - Dict: A dictionary representation of the BrowsingContextPartitionDescriptor. - """ - return {"type": self.type, "context": self.context} - - -class StorageKeyPartitionDescriptor: - """Represents a storage key partition descriptor.""" - - def __init__(self, user_context: str | None = None, source_origin: str | None = None): - self.type = "storageKey" - self.user_context = user_context - self.source_origin = source_origin - - def to_dict(self) -> dict[str, str]: - """Converts the StorageKeyPartitionDescriptor to a dictionary. - - Returns: - Dict: A dictionary representation of the StorageKeyPartitionDescriptor. - """ - result = {"type": self.type} - if self.user_context is not None: - result["userContext"] = self.user_context - if self.source_origin is not None: - result["sourceOrigin"] = self.source_origin - return result - - +@dataclass class PartialCookie: - """Represents a partial cookie for setting.""" - - def __init__( - self, - name: str, - value: BytesValue, - domain: str, - path: str | None = None, - http_only: bool | None = None, - secure: bool | None = None, - same_site: str | None = None, - expiry: int | None = None, - ): - self.name = name - self.value = value - self.domain = domain - self.path = path - self.http_only = http_only - self.secure = secure - self.same_site = same_site - self.expiry = expiry - - def to_dict(self) -> dict[str, Any]: - """Converts the PartialCookie to a dictionary. - - Returns: - ------- - Dict: A dictionary representation of the PartialCookie. - """ - result: dict[str, Any] = { - "name": self.name, - "value": self.value.to_dict(), - "domain": self.domain, - } + """PartialCookie.""" + + name: str | None = None + value: Any | None = None + domain: str | None = None + path: str | None = None + http_only: bool | None = None + secure: bool | None = None + same_site: Any | None = None + expiry: Any | None = None + + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {} + if self.name is not None: + result["name"] = self.name + if self.value is not None: + result["value"] = self.value.to_bidi_dict() if hasattr(self.value, "to_bidi_dict") else self.value + if self.domain is not None: + result["domain"] = self.domain if self.path is not None: result["path"] = self.path if self.http_only is not None: @@ -270,144 +196,99 @@ def to_dict(self) -> dict[str, Any]: result["expiry"] = self.expiry return result +class BrowsingContextPartitionDescriptor: + """BrowsingContextPartitionDescriptor. -class GetCookiesResult: - """Represents the result of a getCookies command.""" - - def __init__(self, cookies: list[Cookie], partition_key: PartitionKey): - self.cookies = cookies - self.partition_key = partition_key - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> GetCookiesResult: - """Creates a GetCookiesResult instance from a dictionary. - - Args: - data: A dictionary containing the get cookies result information. - - Returns: - A new instance of GetCookiesResult. - """ - cookies = [Cookie.from_dict(cookie) for cookie in data.get("cookies", [])] - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(cookies=cookies, partition_key=partition_key) - - -class SetCookieResult: - """Represents the result of a setCookie command.""" - - def __init__(self, partition_key: PartitionKey): - self.partition_key = partition_key - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> SetCookieResult: - """Creates a SetCookieResult instance from a dictionary. - - Args: - data: A dictionary containing the set cookie result information. - - Returns: - A new instance of SetCookieResult. - """ - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(partition_key=partition_key) - - -class DeleteCookiesResult: - """Represents the result of a deleteCookies command.""" + The first positional argument is *context* (a browsing-context ID / window + handle), mirroring how the class is used throughout the test suite: + ``BrowsingContextPartitionDescriptor(driver.current_window_handle)``. + """ - def __init__(self, partition_key: PartitionKey): - self.partition_key = partition_key + def __init__(self, context: Any = None, type: str = "context") -> None: + self.context = context + self.type = type - @classmethod - def from_dict(cls, data: dict[str, Any]) -> DeleteCookiesResult: - """Creates a DeleteCookiesResult instance from a dictionary. + def to_bidi_dict(self) -> dict: + return {"type": "context", "context": self.context} - Args: - data: A dictionary containing the delete cookies result information. +@dataclass +class StorageKeyPartitionDescriptor: + """StorageKeyPartitionDescriptor.""" - Returns: - A new instance of DeleteCookiesResult. - """ - partition_key = PartitionKey.from_dict(data.get("partitionKey", {})) - return cls(partition_key=partition_key) + type: Any | None = "storageKey" + user_context: str | None = None + source_origin: str | None = None + def to_bidi_dict(self) -> dict: + """Serialize to the BiDi wire-protocol dict.""" + result: dict = {"type": "storageKey"} + if self.user_context is not None: + result["userContext"] = self.user_context + if self.source_origin is not None: + result["sourceOrigin"] = self.source_origin + return result class Storage: - """BiDi implementation of the storage module.""" - - def __init__(self, conn: WebSocketConnection) -> None: - self.conn = conn - - def get_cookies( - self, - filter: CookieFilter | None = None, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> GetCookiesResult: - """Gets cookies matching the specified filter. - - Args: - filter: Optional filter to specify which cookies to retrieve. - partition: Optional partition key to limit the scope of the operation. - - Returns: - A GetCookiesResult containing the cookies and partition key. - - Example: - result = await storage.get_cookies( - filter=CookieFilter(name="sessionId"), - partition=PartitionKey(...) + """WebDriver BiDi storage module.""" + + def __init__(self, conn) -> None: + self._conn = conn + + def get_cookies(self, filter=None, partition=None): + """Execute storage.getCookies and return a GetCookiesResult.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.getCookies", params) + result = self._conn.execute(cmd) + if result and "cookies" in result: + cookies = [ + StorageCookie.from_bidi_dict(c) + for c in result.get("cookies", []) + if isinstance(c, dict) + ] + pk_raw = result.get("partitionKey") + pk = ( + PartitionKey( + user_context=pk_raw.get("userContext"), + source_origin=pk_raw.get("sourceOrigin"), + ) + if isinstance(pk_raw, dict) + else None ) - """ - params = {} - if filter is not None: - params["filter"] = filter.to_dict() - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.getCookies", params)) - return GetCookiesResult.from_dict(result) - - def set_cookie( - self, - cookie: PartialCookie, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> SetCookieResult: - """Sets a cookie in the browser. - - Args: - cookie: The cookie to set. - partition: Optional partition descriptor. - - Returns: - The result of the set cookie command. - """ - params = {"cookie": cookie.to_dict()} - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.setCookie", params)) - return SetCookieResult.from_dict(result) - - def delete_cookies( - self, - filter: CookieFilter | None = None, - partition: BrowsingContextPartitionDescriptor | StorageKeyPartitionDescriptor | None = None, - ) -> DeleteCookiesResult: - """Deletes cookies that match the given parameters. - - Args: - filter: Optional filter to match cookies to delete. - partition: Optional partition descriptor. - - Returns: - The result of the delete cookies command. - """ - params = {} - if filter is not None: - params["filter"] = filter.to_dict() - if partition is not None: - params["partition"] = partition.to_dict() - - result = self.conn.execute(command_builder("storage.deleteCookies", params)) - return DeleteCookiesResult.from_dict(result) + return GetCookiesResult(cookies=cookies, partition_key=pk) + return GetCookiesResult(cookies=[], partition_key=None) + def set_cookie(self, cookie=None, partition=None): + """Execute storage.setCookie.""" + if cookie and hasattr(cookie, "to_bidi_dict"): + cookie = cookie.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "cookie": cookie, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.setCookie", params) + result = self._conn.execute(cmd) + return result + def delete_cookies(self, filter=None, partition=None): + """Execute storage.deleteCookies.""" + if filter and hasattr(filter, "to_bidi_dict"): + filter = filter.to_bidi_dict() + if partition and hasattr(partition, "to_bidi_dict"): + partition = partition.to_bidi_dict() + params = { + "filter": filter, + "partition": partition, + } + params = {k: v for k, v in params.items() if v is not None} + cmd = command_builder("storage.deleteCookies", params) + result = self._conn.execute(cmd) + return result diff --git a/py/selenium/webdriver/common/bidi/webextension.py b/py/selenium/webdriver/common/bidi/webextension.py index 7609a04f3b3a4..1c5b342c070d5 100644 --- a/py/selenium/webdriver/common/bidi/webextension.py +++ b/py/selenium/webdriver/common/bidi/webextension.py @@ -1,78 +1,129 @@ -# Licensed to the Software Freedom Conservancy (SFC) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The SFC licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at +# DO NOT EDIT THIS FILE! # -# http://www.apache.org/licenses/LICENSE-2.0 +# This file is generated from the WebDriver BiDi specification. If you need to make +# changes, edit the generator and regenerate all of the modules. # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# WebDriver BiDi module: webExtension +from __future__ import annotations +from typing import Any +from .common import command_builder +from dataclasses import dataclass +from dataclasses import field -from selenium.common.exceptions import WebDriverException -from selenium.webdriver.common.bidi.common import command_builder + +@dataclass +class InstallParameters: + """InstallParameters.""" + + extension_data: Any | None = None + + +@dataclass +class ExtensionPath: + """ExtensionPath.""" + + type: str = field(default="path", init=False) + path: str | None = None + + +@dataclass +class ExtensionArchivePath: + """ExtensionArchivePath.""" + + type: str = field(default="archivePath", init=False) + path: str | None = None + + +@dataclass +class ExtensionBase64Encoded: + """ExtensionBase64Encoded.""" + + type: str = field(default="base64", init=False) + value: str | None = None + + +@dataclass +class InstallResult: + """InstallResult.""" + + extension: Any | None = None + + +@dataclass +class UninstallParameters: + """UninstallParameters.""" + + extension: Any | None = None class WebExtension: - """BiDi implementation of the webExtension module.""" + """WebDriver BiDi webExtension module.""" - def __init__(self, conn): - self.conn = conn + def __init__(self, conn) -> None: + self._conn = conn - def install(self, path=None, archive_path=None, base64_value=None) -> dict: - """Installs a web extension in the remote end. + def install( + self, + path: str | None = None, + archive_path: str | None = None, + base64_value: str | None = None, + ): + """Install a web extension. - You must provide exactly one of the parameters. + Exactly one of the three keyword arguments must be provided. Args: - path: Path to an extension directory. - archive_path: Path to an extension archive file. - base64_value: Base64 encoded string of the extension archive. + path: Directory path to an unpacked extension (also accepted for + signed ``.xpi`` / ``.crx`` archive files on Firefox). + archive_path: File-system path to a packed extension archive. + base64_value: Base64-encoded extension archive string. Returns: - A dictionary containing the extension ID. - """ - if sum(x is not None for x in (path, archive_path, base64_value)) != 1: - raise ValueError("Exactly one of path, archive_path, or base64_value must be provided") + The raw result dict from the BiDi ``webExtension.install`` command + (contains at least an ``"extension"`` key with the extension ID). + Raises: + ValueError: If more than one, or none, of the arguments is provided. + """ + provided = [ + k for k, v in { + "path": path, "archive_path": archive_path, "base64_value": base64_value, + }.items() if v is not None + ] + if len(provided) != 1: + raise ValueError( + f"Exactly one of path, archive_path, or base64_value must be provided; got: {provided}" + ) if path is not None: extension_data = {"type": "path", "path": path} elif archive_path is not None: extension_data = {"type": "archivePath", "path": archive_path} - elif base64_value is not None: + else: + assert base64_value is not None extension_data = {"type": "base64", "value": base64_value} - params = {"extensionData": extension_data} - - try: - result = self.conn.execute(command_builder("webExtension.install", params)) - return result - except WebDriverException as e: - if "Method not available" in str(e): - raise WebDriverException( - f"{e!s}. If you are using Chrome or Edge, add '--enable-unsafe-extension-debugging' " - "and '--remote-debugging-pipe' arguments or set options.enable_webextensions = True" - ) from e - raise - - def uninstall(self, extension_id_or_result: str | dict) -> None: - """Uninstalls a web extension from the remote end. + cmd = command_builder("webExtension.install", params) + return self._conn.execute(cmd) + def uninstall(self, extension: str | dict): + """Uninstall a web extension. Args: - extension_id_or_result: Either the extension ID as a string or the result dictionary - from a previous install() call containing the extension ID. + extension: Either the extension ID string returned by ``install``, + or the full result dict returned by ``install`` (the + ``"extension"`` value is extracted automatically). + + Raises: + ValueError: If extension is not provided or is None. """ - if isinstance(extension_id_or_result, dict): - extension_id = extension_id_or_result.get("extension") + if isinstance(extension, dict): + extension_id: Any = extension.get("extension") else: - extension_id = extension_id_or_result + extension_id = extension + + if extension_id is None: + raise ValueError("extension parameter is required") params = {"extension": extension_id} - self.conn.execute(command_builder("webExtension.uninstall", params)) + cmd = command_builder("webExtension.uninstall", params) + return self._conn.execute(cmd) diff --git a/py/selenium/webdriver/common/by.py b/py/selenium/webdriver/common/by.py index d2a10ac70a7c6..9cc0ac1b1864a 100644 --- a/py/selenium/webdriver/common/by.py +++ b/py/selenium/webdriver/common/by.py @@ -16,9 +16,20 @@ # under the License. """The By implementation.""" +from __future__ import annotations + from typing import Literal -ByType = Literal["id", "xpath", "link text", "partial link text", "name", "tag name", "class name", "css selector"] +ByType = Literal[ + "id", + "xpath", + "link text", + "partial link text", + "name", + "tag name", + "class name", + "css selector", +] class By: diff --git a/py/selenium/webdriver/common/proxy.py b/py/selenium/webdriver/common/proxy.py index 89172d1122c36..28de19afa5742 100644 --- a/py/selenium/webdriver/common/proxy.py +++ b/py/selenium/webdriver/common/proxy.py @@ -17,6 +17,8 @@ """The Proxy implementation.""" +from __future__ import annotations + class ProxyTypeFactory: """Factory for proxy types.""" @@ -33,13 +35,23 @@ class ProxyType: profile preference, 'string' is id of proxy type. """ - DIRECT = ProxyTypeFactory.make(0, "DIRECT") # Direct connection, no proxy (default on Windows). - MANUAL = ProxyTypeFactory.make(1, "MANUAL") # Manual proxy settings (e.g., for httpProxy). + DIRECT = ProxyTypeFactory.make( + 0, "DIRECT" + ) # Direct connection, no proxy (default on Windows). + MANUAL = ProxyTypeFactory.make( + 1, "MANUAL" + ) # Manual proxy settings (e.g., for httpProxy). PAC = ProxyTypeFactory.make(2, "PAC") # Proxy autoconfiguration from URL. RESERVED_1 = ProxyTypeFactory.make(3, "RESERVED1") # Never used. - AUTODETECT = ProxyTypeFactory.make(4, "AUTODETECT") # Proxy autodetection (presumably with WPAD). - SYSTEM = ProxyTypeFactory.make(5, "SYSTEM") # Use system settings (default on Linux). - UNSPECIFIED = ProxyTypeFactory.make(6, "UNSPECIFIED") # Not initialized (for internal use). + AUTODETECT = ProxyTypeFactory.make( + 4, "AUTODETECT" + ) # Proxy autodetection (presumably with WPAD). + SYSTEM = ProxyTypeFactory.make( + 5, "SYSTEM" + ) # Use system settings (default on Linux). + UNSPECIFIED = ProxyTypeFactory.make( + 6, "UNSPECIFIED" + ) # Not initialized (for internal use). @classmethod def load(cls, value): @@ -48,7 +60,11 @@ def load(cls, value): value = str(value).upper() for attr in dir(cls): attr_value = getattr(cls, attr) - if isinstance(attr_value, dict) and "string" in attr_value and attr_value["string"] == value: + if ( + isinstance(attr_value, dict) + and "string" in attr_value + and attr_value["string"] == value + ): return attr_value raise Exception(f"No proxy type is found for {value}") @@ -203,13 +219,17 @@ def to_bidi_dict(self) -> dict: if self.noProxy: # Convert comma-separated string to list if isinstance(self.noProxy, str): - result["noProxy"] = [host.strip() for host in self.noProxy.split(",") if host.strip()] + result["noProxy"] = [ + host.strip() for host in self.noProxy.split(",") if host.strip() + ] elif isinstance(self.noProxy, list): if not all(isinstance(h, str) for h in self.noProxy): raise TypeError("no_proxy list must contain only strings") result["noProxy"] = self.noProxy else: - raise TypeError("no_proxy must be a comma-separated string or a list of strings") + raise TypeError( + "no_proxy must be a comma-separated string or a list of strings" + ) elif proxy_type == "pac": if self.proxyAutoconfigUrl: diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 573f967f4a50c..ed05f8c1585f8 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -20,6 +20,7 @@ import base64 import contextlib import copy +import inspect import os import pkgutil import tempfile @@ -113,7 +114,9 @@ def get_remote_connection( client_config: ClientConfig | None = None, ) -> RemoteConnection: if isinstance(command_executor, str): - client_config = client_config or ClientConfig(remote_server_addr=command_executor) + client_config = client_config or ClientConfig( + remote_server_addr=command_executor + ) client_config.remote_server_addr = command_executor command_executor = RemoteConnection(client_config=client_config) @@ -394,9 +397,13 @@ def create_web_element(self, element_id: str) -> WebElement: def _unwrap_value(self, value): if isinstance(value, dict): if "element-6066-11e4-a52e-4f735466cecf" in value: - return self.create_web_element(value["element-6066-11e4-a52e-4f735466cecf"]) + return self.create_web_element( + value["element-6066-11e4-a52e-4f735466cecf"] + ) if "shadow-6066-11e4-a52e-4f735466cecf" in value: - return self._shadowroot_cls(self, value["shadow-6066-11e4-a52e-4f735466cecf"]) + return self._shadowroot_cls( + self, value["shadow-6066-11e4-a52e-4f735466cecf"] + ) for key, val in value.items(): value[key] = self._unwrap_value(val) return value @@ -422,18 +429,33 @@ def execute_cdp_cmd(self, cmd: str, cmd_args: dict): Example: `driver.execute_cdp_cmd("Network.getResponseBody", {"requestId": requestId})` """ - return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})["value"] + return self.execute("executeCdpCommand", {"cmd": cmd, "params": cmd_args})[ + "value" + ] - def execute(self, driver_command: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + def execute( + self, driver_command: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: """Sends a command to be executed by a command.CommandExecutor. Args: - driver_command: The name of the command to execute as a string. + driver_command: The name of the command to execute as a string. Can also be a generator + for BiDi protocol commands. params: A dictionary of named parameters to send with the command. Returns: The command's JSON response loaded into a dictionary object. """ + # Handle BiDi generator commands + if inspect.isgenerator(driver_command): + # BiDi command: route through the WebSocket connection, not the + # HTTP RemoteConnection which only accepts (command, params) pairs. + if not self._websocket_connection: + self._start_bidi() + assert self._websocket_connection is not None + return self._websocket_connection.execute(driver_command) + + # Legacy WebDriver command: handle normally params = self._wrap_value(params) if self.session_id: @@ -442,7 +464,9 @@ def execute(self, driver_command: str, params: dict[str, Any] | None = None) -> elif "sessionId" not in params: params["sessionId"] = self.session_id - response = cast(RemoteConnection, self.command_executor).execute(driver_command, params) + response = cast(RemoteConnection, self.command_executor).execute( + driver_command, params + ) if response: self.error_handler.check_response(response) @@ -498,7 +522,9 @@ def unpin(self, script_key: ScriptKey) -> None: try: self.pinned_scripts.pop(script_key.id) except KeyError: - raise KeyError(f"No script with key: {script_key} existed in {self.pinned_scripts}") from None + raise KeyError( + f"No script with key: {script_key} existed in {self.pinned_scripts}" + ) from None def get_pinned_scripts(self) -> list[str]: """Return a list of all pinned scripts. @@ -531,7 +557,9 @@ def execute_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute(command, {"script": script, "args": converted_args})[ + "value" + ] def execute_async_script(self, script: str, *args) -> Any: """Asynchronously Executes JavaScript in the current window/frame. @@ -550,7 +578,9 @@ def execute_async_script(self, script: str, *args) -> Any: converted_args = list(args) command = Command.W3C_EXECUTE_SCRIPT_ASYNC - return self.execute(command, {"script": script, "args": converted_args})["value"] + return self.execute(command, {"script": script, "args": converted_args})[ + "value" + ] @property def current_url(self) -> str: @@ -724,7 +754,9 @@ def implicitly_wait(self, time_to_wait: float) -> None: Example: `driver.implicitly_wait(30)` """ - self.execute(Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"implicit": int(float(time_to_wait) * 1000)} + ) def set_script_timeout(self, time_to_wait: float) -> None: """Set the timeout for asynchronous script execution. @@ -753,9 +785,14 @@ def set_page_load_timeout(self, time_to_wait: float) -> None: `driver.set_page_load_timeout(30)` """ try: - self.execute(Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)}) + self.execute( + Command.SET_TIMEOUTS, {"pageLoad": int(float(time_to_wait) * 1000)} + ) except WebDriverException: - self.execute(Command.SET_TIMEOUTS, {"ms": float(time_to_wait) * 1000, "type": "page load"}) + self.execute( + Command.SET_TIMEOUTS, + {"ms": float(time_to_wait) * 1000, "type": "page load"}, + ) @property def timeouts(self) -> Timeouts: @@ -791,7 +828,9 @@ def timeouts(self, timeouts) -> None: """ _ = self.execute(Command.SET_TIMEOUTS, timeouts._to_json())["value"] - def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) -> WebElement: + def find_element( + self, by: str | RelativeBy = By.ID, value: str | None = None + ) -> WebElement: """Find an element given a By strategy and locator. Args: @@ -812,12 +851,18 @@ def find_element(self, by: str | RelativeBy = By.ID, value: str | None = None) - if isinstance(by, RelativeBy): elements = self.find_elements(by=by, value=value) if not elements: - raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") + raise NoSuchElementException( + f"Cannot locate relative element with: {by.root}" + ) return elements[0] - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] + return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})[ + "value" + ] - def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) -> list[WebElement]: + def find_elements( + self, by: str | RelativeBy = By.ID, value: str | None = None + ) -> list[WebElement]: """Find elements given a By strategy and locator. Args: @@ -839,14 +884,21 @@ def find_elements(self, by: str | RelativeBy = By.ID, value: str | None = None) _pkg = ".".join(__name__.split(".")[:-1]) raw_data = pkgutil.get_data(_pkg, "findElements.js") if raw_data is None: - raise FileNotFoundError(f"Could not find findElements.js in package {_pkg}") + raise FileNotFoundError( + f"Could not find findElements.js in package {_pkg}" + ) raw_function = raw_data.decode("utf8") - find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" + find_element_js = ( + f"/* findElements */return ({raw_function}).apply(null, arguments);" + ) return self.execute_script(find_element_js, by.to_dict()) # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 - return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] + return ( + self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] + or [] + ) @property def capabilities(self) -> dict: @@ -943,7 +995,9 @@ def get_window_size(self, windowHandle: str = "current") -> dict: return {k: size[k] for k in ("width", "height")} - def set_window_position(self, x: float, y: float, windowHandle: str = "current") -> dict: + def set_window_position( + self, x: float, y: float, windowHandle: str = "current" + ) -> dict: """Sets the x,y position of the current window. Args: @@ -971,7 +1025,10 @@ def get_window_position(self, windowHandle="current") -> dict: def _check_if_window_handle_is_current(self, windowHandle: str) -> None: """Warns if the window handle is not equal to `current`.""" if windowHandle != "current": - warnings.warn("Only 'current' window is supported for W3C compatible browsers.", stacklevel=2) + warnings.warn( + "Only 'current' window is supported for W3C compatible browsers.", + stacklevel=2, + ) def get_window_rect(self) -> dict: """Get the window's position and size. @@ -999,7 +1056,9 @@ def set_window_rect(self, x=None, y=None, width=None, height=None) -> dict: if (x is None and y is None) and (not height and not width): raise InvalidArgumentException("x and y or height and width need values") - return self.execute(Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height})["value"] + return self.execute( + Command.SET_WINDOW_RECT, {"x": x, "y": y, "width": width, "height": height} + )["value"] @property def file_detector(self) -> FileDetector: @@ -1044,7 +1103,9 @@ def orientation(self, value) -> None: if value.upper() in allowed_values: self.execute(Command.SET_SCREEN_ORIENTATION, {"orientation": value}) else: - raise WebDriverException("You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'") + raise WebDriverException( + "You can only set the orientation to 'LANDSCAPE' and 'PORTRAIT'" + ) def start_devtools(self) -> tuple[Any, WebSocketConnection]: global cdp @@ -1059,7 +1120,9 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) if cdp is None: raise WebDriverException("CDP module not loaded") @@ -1068,20 +1131,28 @@ def start_devtools(self) -> tuple[Any, WebSocketConnection]: if self._websocket_connection: return self._devtools, self._websocket_connection if self.caps["browserName"].lower() == "firefox": - raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.") + raise RuntimeError( + "CDP support for Firefox has been removed. Please switch to WebDriver BiDi." + ) if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException("command_executor must be a RemoteConnection instance for CDP support") + raise WebDriverException( + "command_executor must be a RemoteConnection instance for CDP support" + ) self._websocket_connection = WebSocketConnection( ws_url, self.command_executor.client_config.websocket_timeout, self.command_executor.client_config.websocket_interval, ) - targets = self._websocket_connection.execute(self._devtools.target.get_targets()) + targets = self._websocket_connection.execute( + self._devtools.target.get_targets() + ) for target in targets: if target.target_id == self.current_window_handle: target_id = target.target_id break - session = self._websocket_connection.execute(self._devtools.target.attach_to_target(target_id, True)) + session = self._websocket_connection.execute( + self._devtools.target.attach_to_target(target_id, True) + ) self._websocket_connection.session_id = session return self._devtools, self._websocket_connection @@ -1096,7 +1167,9 @@ async def bidi_connection(self): version, ws_url = self._get_cdp_details() if not ws_url: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) devtools = cdp.import_devtools(version) async with cdp.open_cdp(ws_url) as conn: @@ -1122,10 +1195,14 @@ def _start_bidi(self) -> None: if self.caps.get("webSocketUrl"): ws_url = self.caps.get("webSocketUrl") else: - raise WebDriverException("Unable to find url to connect to from capabilities") + raise WebDriverException( + "Unable to find url to connect to from capabilities" + ) if not isinstance(self.command_executor, RemoteConnection): - raise WebDriverException("command_executor must be a RemoteConnection instance for BiDi support") + raise WebDriverException( + "command_executor must be a RemoteConnection instance for BiDi support" + ) self._websocket_connection = WebSocketConnection( ws_url, @@ -1323,9 +1400,13 @@ def _get_cdp_details(self): http = urllib3.PoolManager() try: if self.caps.get("browserName") == "chrome": - debugger_address = self.caps.get("goog:chromeOptions").get("debuggerAddress") + debugger_address = self.caps.get("goog:chromeOptions").get( + "debuggerAddress" + ) elif self.caps.get("browserName") in ("MicrosoftEdge", "webview2"): - debugger_address = self.caps.get("ms:edgeOptions").get("debuggerAddress") + debugger_address = self.caps.get("ms:edgeOptions").get( + "debuggerAddress" + ) except AttributeError: raise WebDriverException("Can't get debugger address.") @@ -1353,7 +1434,9 @@ def add_virtual_authenticator(self, options: VirtualAuthenticatorOptions) -> Non driver.add_virtual_authenticator(options) ``` """ - self._authenticator_id = self.execute(Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict())["value"] + self._authenticator_id = self.execute( + Command.ADD_VIRTUAL_AUTHENTICATOR, options.to_dict() + )["value"] @property def virtual_authenticator_id(self) -> str | None: @@ -1367,7 +1450,10 @@ def remove_virtual_authenticator(self) -> None: The authenticator is no longer valid after removal, so no methods may be called. """ - self.execute(Command.REMOVE_VIRTUAL_AUTHENTICATOR, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_VIRTUAL_AUTHENTICATOR, + {"authenticatorId": self._authenticator_id}, + ) self._authenticator_id = None @required_virtual_authenticator @@ -1382,13 +1468,20 @@ def add_credential(self, credential: Credential) -> None: driver.add_credential(credential) ``` """ - self.execute(Command.ADD_CREDENTIAL, {**credential.to_dict(), "authenticatorId": self._authenticator_id}) + self.execute( + Command.ADD_CREDENTIAL, + {**credential.to_dict(), "authenticatorId": self._authenticator_id}, + ) @required_virtual_authenticator def get_credentials(self) -> list[Credential]: """Returns the list of credentials owned by the authenticator.""" - credential_data = self.execute(Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id}) - return [Credential.from_dict(credential) for credential in credential_data["value"]] + credential_data = self.execute( + Command.GET_CREDENTIALS, {"authenticatorId": self._authenticator_id} + ) + return [ + Credential.from_dict(credential) for credential in credential_data["value"] + ] @required_virtual_authenticator def remove_credential(self, credential_id: str | bytearray) -> None: @@ -1403,13 +1496,16 @@ def remove_credential(self, credential_id: str | bytearray) -> None: credential_id = urlsafe_b64encode(credential_id).decode() self.execute( - Command.REMOVE_CREDENTIAL, {"credentialId": credential_id, "authenticatorId": self._authenticator_id} + Command.REMOVE_CREDENTIAL, + {"credentialId": credential_id, "authenticatorId": self._authenticator_id}, ) @required_virtual_authenticator def remove_all_credentials(self) -> None: """Removes all credentials from the authenticator.""" - self.execute(Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id}) + self.execute( + Command.REMOVE_ALL_CREDENTIALS, {"authenticatorId": self._authenticator_id} + ) @required_virtual_authenticator def set_user_verified(self, verified: bool) -> None: @@ -1422,12 +1518,17 @@ def set_user_verified(self, verified: bool) -> None: Example: `driver.set_user_verified(True)` """ - self.execute(Command.SET_USER_VERIFIED, {"authenticatorId": self._authenticator_id, "isUserVerified": verified}) + self.execute( + Command.SET_USER_VERIFIED, + {"authenticatorId": self._authenticator_id, "isUserVerified": verified}, + ) def get_downloadable_files(self) -> list: """Retrieves the downloadable files as a list of file names.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"] @@ -1442,12 +1543,16 @@ def download_file(self, file_name: str, target_directory: str) -> None: `driver.download_file("example.zip", "/path/to/directory")` """ if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) if not os.path.exists(target_directory): os.makedirs(target_directory) - contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"]["contents"] + contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"][ + "contents" + ] with tempfile.TemporaryDirectory() as tmp_dir: zip_file = os.path.join(tmp_dir, file_name + ".zip") @@ -1460,7 +1565,9 @@ def download_file(self, file_name: str, target_directory: str) -> None: def delete_downloadable_files(self) -> None: """Deletes all downloadable files.""" if "se:downloadsEnabled" not in self.capabilities: - raise WebDriverException("You must enable downloads in order to work with downloadable files.") + raise WebDriverException( + "You must enable downloads in order to work with downloadable files." + ) self.execute(Command.DELETE_DOWNLOADABLE_FILES) @@ -1566,5 +1673,10 @@ def _check_fedcm() -> Dialog | None: except NoAlertPresentException: return None - wait = WebDriverWait(self, timeout, poll_frequency=poll_frequency, ignored_exceptions=ignored_exceptions) + wait = WebDriverWait( + self, + timeout, + poll_frequency=poll_frequency, + ignored_exceptions=ignored_exceptions, + ) return wait.until(lambda _: _check_fedcm()) diff --git a/py/selenium/webdriver/remote/websocket_connection.py b/py/selenium/webdriver/remote/websocket_connection.py index 98bf4f4b9057a..b4cac118df033 100644 --- a/py/selenium/webdriver/remote/websocket_connection.py +++ b/py/selenium/webdriver/remote/websocket_connection.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +import dataclasses import json import logging from ssl import CERT_NONE @@ -25,6 +26,51 @@ from selenium.common import WebDriverException + +def _snake_to_camel(name: str) -> str: + """Convert snake_case field name to camelCase for BiDi protocol.""" + parts = name.split("_") + return parts[0] + "".join(p.title() for p in parts[1:]) + + +class _BiDiEncoder(json.JSONEncoder): + """JSON encoder for BiDi dataclass instances. + + Converts snake_case field names to camelCase, strips ``None`` values, + and flattens a ``properties`` field (e.g. ``PointerCommonProperties``) + directly into its parent action dict as required by the BiDi spec. + """ + + def _convert(self, value): + """Recursively convert a value, handling nested dataclasses, lists, and dicts.""" + if dataclasses.is_dataclass(value) and not isinstance(value, type): + return self.default(value) + if isinstance(value, list): + return [self._convert(item) for item in value] + if isinstance(value, dict): + return {k: self._convert(v) for k, v in value.items()} + return value + + def default(self, o): + if dataclasses.is_dataclass(o) and not isinstance(o, type): + result = {} + for f in dataclasses.fields(o): + value = getattr(o, f.name) + if value is None: + continue + camel_key = _snake_to_camel(f.name) + # Flatten PointerCommonProperties fields inline into the parent + if camel_key == "properties" and dataclasses.is_dataclass(value): + for pf in dataclasses.fields(value): + pv = getattr(value, pf.name) + if pv is not None: + result[_snake_to_camel(pf.name)] = self._convert(pv) + else: + result[camel_key] = self._convert(value) + return result + return super().default(o) + + logger = logging.getLogger(__name__) @@ -63,7 +109,7 @@ def execute(self, command): if self.session_id: payload["sessionId"] = self.session_id - data = json.dumps(payload) + data = json.dumps(payload, cls=_BiDiEncoder) logger.debug(f"-> {data}"[: self._max_log_message_size]) self._ws.send(data) @@ -109,7 +155,9 @@ def _serialize_command(self, command): def _deserialize_result(self, result, command): try: _ = command.send(result) - raise WebDriverException("The command's generator function did not exit when expected!") + raise WebDriverException( + "The command's generator function did not exit when expected!" + ) except StopIteration as exit: return exit.value @@ -126,11 +174,15 @@ def on_error(ws, error): def run_socket(): if self.url.startswith("wss://"): - self._ws.run_forever(sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True) + self._ws.run_forever( + sslopt={"cert_reqs": CERT_NONE}, suppress_origin=True + ) else: self._ws.run_forever(suppress_origin=True) - self._ws = WebSocketApp(self.url, on_open=on_open, on_message=on_message, on_error=on_error) + self._ws = WebSocketApp( + self.url, on_open=on_open, on_message=on_message, on_error=on_error + ) self._ws_thread = Thread(target=run_socket, daemon=True) self._ws_thread.start() diff --git a/py/test/selenium/webdriver/common/bidi_browser_tests.py b/py/test/selenium/webdriver/common/bidi_browser_tests.py index 7fd054b73627d..b9e042403dcba 100644 --- a/py/test/selenium/webdriver/common/bidi_browser_tests.py +++ b/py/test/selenium/webdriver/common/bidi_browser_tests.py @@ -20,7 +20,7 @@ import pytest from selenium.common.exceptions import TimeoutException -from selenium.webdriver.common.bidi.browser import ClientWindowInfo, ClientWindowState +from selenium.webdriver.common.bidi.browser import ClientWindowInfo, ClientWindowNamedState from selenium.webdriver.common.bidi.browsing_context import ReadinessState from selenium.webdriver.common.bidi.session import UserPromptHandler, UserPromptHandlerType from selenium.webdriver.common.by import By @@ -100,10 +100,9 @@ def test_raises_exception_when_removing_default_user_context(driver): def test_client_window_state_constants(driver): - assert ClientWindowState.FULLSCREEN == "fullscreen" - assert ClientWindowState.MAXIMIZED == "maximized" - assert ClientWindowState.MINIMIZED == "minimized" - assert ClientWindowState.NORMAL == "normal" + """Test ClientWindowNamedState constants.""" + assert ClientWindowNamedState.MAXIMIZED == "maximized" + assert ClientWindowNamedState.MINIMIZED == "minimized" def test_create_user_context_with_accept_insecure_certs(driver): @@ -177,7 +176,7 @@ def test_create_user_context_with_manual_proxy_all_params(driver, proxy_server): # Visit a site that should be proxied driver.get("http://example.com/") - body_text = driver.find_element("tag name", "body").text + body_text = driver.find_element(By.TAG_NAME, "body").text assert "proxied response" in body_text.lower() finally: