Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
# fmt: off
__all__ = (
'Client',
'CacheOptions',
)
# fmt: on

Expand All @@ -150,6 +151,173 @@ def __getattr__(self, attr: str) -> None:
_loop: Any = _LoopSentinel()


class CacheOptions:
"""Represents a :class:`Client` cache control.

Using this may allow you to control what things you want your client
to cache.

.. warning::

Using this may result in unexpected behaviour in certain parts of the
library. For example, disabling ``guilds`` may cause problems on objects
with a ``.guild`` attribute returning ``None``.

All parameters default to ``True`` unless explicitly set to ``False``.

Parameters
----------
guilds: :class:`bool`
Whether to cache guilds. Defaults to whether :attr:`Intents.guilds` is enabled
or not.
users: :class:`bool`
Whether to cache users.
members: :class:`bool`
Whether to cache members. Defaults to whether :attr:`Intents.members` is enabled
or not.
presences: :class:`bool`
Whether to cache members. Defaults to whether :attr:`Intents.presences` is enabled
or not.
voice_states: :class:`bool`
Whether to cache voice states. Defaults to whether :attr:`Intents.voice_states` is enabled
or not.
emojis_and_stickers: :class:`bool`
Whether to cache emojis and stickers. Defaults to whether :attr:`Intents.emojis_and_stickers` is enabled
or not.
soundboard_sounds: :class:`bool`
Whether to cache guild's soundboard sounds.
private_channels: :class:`bool`
Whether to cache private channels (aka DMs/Group DMS).
guild_channels: :class:`bool`
Whether to cache guild channels.
roles: :class:`bool`
Whether to cache guild roles.
threads: :class:`bool`
Whether to cache threads.
scheduled_events: :class:`bool`
Whether to cache scheduled events.
stage_instances: :class:`bool`
Whether to cache stage instances.
"""

__valid_flags__ = (
'guilds',
'users',
'members',
'presences',
'voice_states',
'emojis_and_stickers',
'soundboard_sounds',
'private_channels',
'guild_channels',
'roles',
'threads',
'scheduled_events',
'stage_instances',
)

if TYPE_CHECKING:
guilds: bool
users: bool
members: bool
presences: bool
voice_states: bool
emojis_and_stickers: bool
soundboard_sounds: bool
private_channels: bool
guild_channels: bool
roles: bool
threads: bool
scheduled_events: bool
stage_instances: bool

@overload
def __init__(
self,
*,
guilds: bool = MISSING,
users: bool = MISSING,
members: bool = MISSING,
presences: bool = MISSING,
voice_states: bool = MISSING,
emojis_and_stickers: bool = MISSING,
soundboard_sounds: bool = MISSING,
private_channels: bool = MISSING,
guild_channels: bool = MISSING,
roles: bool = MISSING,
threads: bool = MISSING,
scheduled_events: bool = MISSING,
stage_instances: bool = MISSING,
) -> None: ...

@overload
def __init__(
self,
**kwargs: bool,
) -> None: ...

def __init__(
self,
**kwargs: bool,
) -> None:
self._cache_data = kwargs

def __getattr__(self, name: str) -> Any:
val = self._cache_data.get(name)
if val is not None:
return val if val is not MISSING else False
if name in self.__valid_flags__:
return True # the flag has not been set, but assume they want it enabled
raise AttributeError(f'attribute {name!r} does not exist for {self.__class__.__name__!r}')

def __setattr__(self, name: str, value: bool) -> None:
if name not in self.__valid_flags__:
raise AttributeError(f'attribute {name!r} does not exist for {self.__class__.__name__!r}')
self._cache_data[name] = value

@classmethod
def from_intents(cls, intents: Intents, /) -> CacheOptions:
"""Creates a new :class:`CacheOptions` instance from a :class:`Intents`
one.

Parameters
----------
intents: :class:`Intents`
The intents to use.

Returns
-------
:class:`CacheOptions`
A new instance of cache options.
"""

return cls(
guilds=intents.guilds,
members=intents.members,
presences=intents.presences,
voice_states=intents.voice_states,
emojis_and_stickers=intents.emojis_and_stickers,
)

@classmethod
def none(cls) -> CacheOptions:
"""Creates a new :class:`CacheOptions` instance with nothing enabled.

Returns
-------
:class:`CacheOptions`
A new instance of cache options.
"""
return cls(
**{flag: False for flag in cls.__valid_flags__}
)

def _update_from_intents(self, intents: Intents) -> None:
for key in ('guilds', 'members', 'presences', 'voice_states', 'emojis_and_stickers'):
ret = getattr(self, key, False)
self._cache_data[key] = ret or getattr(intents, key, False)


class Client:
r"""Represents a client connection that connects to Discord.
This class is used to interact with the Discord WebSocket and API.
Expand Down Expand Up @@ -265,6 +433,10 @@ class Client:
behavior, such as setting a dns resolver or sslcontext.

.. versionadded:: 2.5
global_ratelimit_sleep_time: :class:`float`
The time to sleep to prevent global ratelimits. Defaults to ``5``.
cache_options: :class:`CacheOptions`
The cache options of this client.

Attributes
-----------
Expand All @@ -286,6 +458,7 @@ def __init__(self, *, intents: Intents, **options: Any) -> None:
unsync_clock: bool = options.pop('assume_unsync_clock', True)
http_trace: Optional[aiohttp.TraceConfig] = options.pop('http_trace', None)
max_ratelimit_timeout: Optional[float] = options.pop('max_ratelimit_timeout', None)
global_sleep_time: Optional[float] = options.pop('global_ratelimit_sleep_time', None)
self.http: HTTPClient = HTTPClient(
self.loop,
connector,
Expand All @@ -294,6 +467,7 @@ def __init__(self, *, intents: Intents, **options: Any) -> None:
unsync_clock=unsync_clock,
http_trace=http_trace,
max_ratelimit_timeout=max_ratelimit_timeout,
global_sleep_time=global_sleep_time,
)

self._handlers: Dict[str, Callable[..., None]] = {
Expand All @@ -304,6 +478,10 @@ def __init__(self, *, intents: Intents, **options: Any) -> None:
'before_identify': self._call_before_identify_hook,
}

cache_options: Optional[CacheOptions] = options.get('cache_options')
if cache_options is None:
options['cache_options'] = CacheOptions.from_intents(intents)

self._enable_debug_events: bool = options.pop('enable_debug_events', False)
self._connection: ConnectionState[Self] = self._get_state(intents=intents, **options)
self._connection.shard_count = self.shard_count
Expand Down
16 changes: 8 additions & 8 deletions discord/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,13 +458,13 @@ class Guild(Hashable):
}

def __init__(self, *, data: GuildPayload, state: ConnectionState) -> None:
self._channels: Dict[int, GuildChannel] = {}
self._members: Dict[int, Member] = {}
self._voice_states: Dict[int, VoiceState] = {}
self._threads: Dict[int, Thread] = {}
self._stage_instances: Dict[int, StageInstance] = {}
self._scheduled_events: Dict[int, ScheduledEvent] = {}
self._soundboard_sounds: Dict[int, SoundboardSound] = {}
self._channels: Dict[int, GuildChannel] = {} if state.cache_options.guild_channels else state.create_immutable_dict()
self._members: Dict[int, Member] = {} if state.cache_options.members else state.create_immutable_dict()
self._voice_states: Dict[int, VoiceState] = {} if state.cache_options.voice_states else state.create_immutable_dict()
self._threads: Dict[int, Thread] = {} if state.cache_options.threads else state.create_immutable_dict()
self._stage_instances: Dict[int, StageInstance] = {} if state.cache_options.stage_instances else state.create_immutable_dict()
self._scheduled_events: Dict[int, ScheduledEvent] = {} if state.cache_options.scheduled_events else state.create_immutable_dict()
self._soundboard_sounds: Dict[int, SoundboardSound] = {} if state.cache_options.soundboard_sounds else state.create_immutable_dict()
self._state: ConnectionState = state
self._member_count: Optional[int] = None
self._from_data(data)
Expand Down Expand Up @@ -589,7 +589,7 @@ def _from_data(self, guild: GuildPayload) -> None:
self._banner: Optional[str] = guild.get('banner')
self.unavailable: bool = guild.get('unavailable', False)
self.id: int = int(guild['id'])
self._roles: Dict[int, Role] = {}
self._roles: Dict[int, Role] = {} if self._state.cache_options.roles else self._state.create_immutable_dict()
state = self._state # speed up attribute access
for r in guild.get('roles', []):
role = Role(guild=self, data=r, state=state)
Expand Down
23 changes: 23 additions & 0 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def __init__(
unsync_clock: bool = True,
http_trace: Optional[aiohttp.TraceConfig] = None,
max_ratelimit_timeout: Optional[float] = None,
global_sleep_time: Optional[float] = None,
) -> None:
self.loop: asyncio.AbstractEventLoop = loop
self.connector: aiohttp.BaseConnector = connector or MISSING
Expand All @@ -539,6 +540,8 @@ def __init__(

user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}'
self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__)
self._request_history: deque[float] = deque()
self.global_sleep_time: float = global_sleep_time if global_sleep_time is not None else 5

def clear(self) -> None:
if self.__session and self.__session.closed:
Expand Down Expand Up @@ -575,6 +578,15 @@ def get_ratelimit(self, key: str) -> Ratelimit:
self._try_clear_expired_ratelimits()
return value

def clear_global_requests(self, now: Optional[float] = None) -> None:
if now is None:
now = self.loop.time()

self._request_history = deque(
t for t in self._request_history
if t > (now - 1)
)

async def request(
self,
route: Route,
Expand Down Expand Up @@ -633,6 +645,17 @@ async def request(
data: Optional[Union[Dict[str, Any], str]] = None
async with ratelimit:
for tries in range(5):
now = self.loop.time()
if len(self._request_history) + 1 >= 50:
first = self._request_history.popleft()

if now - first <= 1:
_log.info(f'Sleeping {self.global_sleep_time} seconds to prevent global ratelimit...')
await asyncio.sleep(self.global_sleep_time)
self.clear_global_requests(now)

self._request_history.append(now)

if files:
for f in files:
f.reset(seek=tries)
Expand Down
Loading