From 3d4bdcc49829e8defcf76f970f1d9ce137f3b7f0 Mon Sep 17 00:00:00 2001 From: Alex Root Junior Date: Fri, 23 Aug 2019 23:29:59 +0300 Subject: [PATCH] Fix `Dispatcher.throttle(...)` and rename user & chat arguments to user_id & chat_id --- aiogram/dispatcher/dispatcher.py | 59 ++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/aiogram/dispatcher/dispatcher.py b/aiogram/dispatcher/dispatcher.py index a5bf5b9f..6891f8be 100644 --- a/aiogram/dispatcher/dispatcher.py +++ b/aiogram/dispatcher/dispatcher.py @@ -8,6 +8,7 @@ import typing import aiohttp from aiohttp.helpers import sentinel +from aiogram.utils.deprecated import renamed_argument from .filters import Command, ContentTypeFilter, ExceptionsFilter, FiltersFactory, HashTag, Regexp, \ RegexpCommandsFilter, StateFilter, Text, IDFilter, AdminFilter from .handler import Handler @@ -914,15 +915,17 @@ class Dispatcher(DataMixin, ContextInstanceMixin): return FSMContext(storage=self.storage, chat=chat, user=user) - async def throttle(self, key, *, rate=None, user=None, chat=None, no_error=None) -> bool: + @renamed_argument(old_name='user', new_name='user_id', until_version='3.0', stacklevel=4) + @renamed_argument(old_name='chat', new_name='chat_id', until_version='3.0', stacklevel=4) + async def throttle(self, key, *, rate=None, user_id=None, chat_id=None, no_error=None) -> bool: """ Execute throttling manager. Returns True if limit has not exceeded otherwise raises ThrottleError or returns False :param key: key in storage :param rate: limit (by default is equal to default rate limit) - :param user: user id - :param chat: chat id + :param user_id: user id + :param chat_id: chat id :param no_error: return boolean value instead of raising error :return: bool """ @@ -933,14 +936,14 @@ class Dispatcher(DataMixin, ContextInstanceMixin): no_error = self.no_throttle_error if rate is None: rate = self.throttling_rate_limit - if user is None and chat is None: - user = types.User.get_current() - chat = types.Chat.get_current() + if user_id is None and chat_id is None: + user_id = types.User.get_current().id + chat_id = types.Chat.get_current().id # Detect current time now = time.time() - bucket = await self.storage.get_bucket(chat=chat, user=user) + bucket = await self.storage.get_bucket(chat=chat_id, user=user_id) # Fix bucket if bucket is None: @@ -964,53 +967,57 @@ class Dispatcher(DataMixin, ContextInstanceMixin): else: data[EXCEEDED_COUNT] = 1 bucket[key].update(data) - await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) + await self.storage.set_bucket(chat=chat_id, user=user_id, bucket=bucket) if not result and not no_error: # Raise if it is allowed - raise Throttled(key=key, chat=chat, user=user, **data) + raise Throttled(key=key, chat=chat_id, user=user_id, **data) return result - async def check_key(self, key, chat=None, user=None): + @renamed_argument('user', 'user_id', '3.0') + @renamed_argument('chat', 'chat_id', '3.0') + async def check_key(self, key, chat_id=None, user_id=None): """ Get information about key in bucket :param key: - :param chat: - :param user: + :param chat_id: + :param user_id: :return: """ if not self.storage.has_bucket(): raise RuntimeError('This storage does not provide Leaky Bucket') - if user is None and chat is None: - user = types.User.get_current() - chat = types.Chat.get_current() + if user_id is None and chat_id is None: + user_id = types.User.get_current() + chat_id = types.Chat.get_current() - bucket = await self.storage.get_bucket(chat=chat, user=user) + bucket = await self.storage.get_bucket(chat=chat_id, user=user_id) data = bucket.get(key, {}) - return Throttled(key=key, chat=chat, user=user, **data) + return Throttled(key=key, chat=chat_id, user=user_id, **data) - async def release_key(self, key, chat=None, user=None): + @renamed_argument('user', 'user_id', '3.0') + @renamed_argument('chat', 'chat_id', '3.0') + async def release_key(self, key, chat_id=None, user_id=None): """ Release blocked key :param key: - :param chat: - :param user: + :param chat_id: + :param user_id: :return: """ if not self.storage.has_bucket(): raise RuntimeError('This storage does not provide Leaky Bucket') - if user is None and chat is None: - user = types.User.get_current() - chat = types.Chat.get_current() + if user_id is None and chat_id is None: + user_id = types.User.get_current() + chat_id = types.Chat.get_current() - bucket = await self.storage.get_bucket(chat=chat, user=user) + bucket = await self.storage.get_bucket(chat=chat_id, user=user_id) if bucket and key in bucket: del bucket['key'] - await self.storage.set_bucket(chat=chat, user=user, bucket=bucket) + await self.storage.set_bucket(chat=chat_id, user=user_id, bucket=bucket) return True return False @@ -1086,7 +1093,7 @@ class Dispatcher(DataMixin, ContextInstanceMixin): async def wrapped(*args, **kwargs): is_not_throttled = await self.throttle(key if key is not None else func.__name__, rate=rate, - user=user_id, chat=chat_id, + user_id=user_id, chat_id=chat_id, no_error=True) if is_not_throttled: return await func(*args, **kwargs)