import asyncio
from typing import Optional, Callable, Any, Union
import botocore.exceptions
import ujson as json
from aioredlock import Aioredlock, LockError
from inspect import signature, isawaitable, isfunction
from insanic.conf import settings
from insanic.exceptions import ImproperlyConfigured
# from insanic.log import logger, error_logger
from iniesta.exceptions import StopPolling
from iniesta.log import logger, error_logger
from iniesta.sessions import BotoSession
from iniesta.sns import SNSClient
from iniesta.utils import filter_list_to_filter_policies
from .message import SQSMessage
# Sentinel used as the ``handlers`` key for the catch-all (default) handler.
default = object()


class SQSClient:
    """
    Client that polls an AWS SQS queue and dispatches each received message
    to the handler registered for the message's event.

    Handlers and resolved queue urls are stored on the class, so they are
    shared by every instance. A queue must be resolved with
    :meth:`initialize` before a client for it can be instantiated.
    """

    endpoint_url = None
    lock_key = "sqs:event:{message_id}"
    handlers = {}  # dict with {event: handler function}
    queue_urls = {}  # dict with {queue_name: queue_url}

    # Polling state. Class-level defaults so that stop_receiving_messages()
    # is safe to call even when start_receiving_messages() never ran.
    _receive_messages = False
    _polling_task = None
    _loop = None

    def __init__(
        self,
        *,
        queue_name: str = None,
        endpoint_url: str = None,
        region_name: str = None,
        retry_count: int = None,
        lock_timeout: int = None,
    ):
        """
        Initializes a SQSClient instance

        :param queue_name: If None, defaults to the value of :meth:`default_queue_name`
        :param endpoint_url: SQS endpoint, defaults to ``INIESTA_SQS_ENDPOINT_URL`` if set
        :param region_name: AWS region, defaults to ``BotoSession.aws_default_region``
        :param retry_count: retry count for aioredlock, defaults to ``INIESTA_LOCK_RETRY_COUNT``
        :param lock_timeout: lock timeout for aioredlock. Defaults to ``INIESTA_LOCK_TIMEOUT``
        :raise KeyError: If application was not initialized with one of the initialization methods.
        """
        self.queue_name = (
            self.default_queue_name() if queue_name is None else queue_name
        )
        self.region_name = region_name or BotoSession.aws_default_region

        try:
            self.queue_url = self.queue_urls[self.queue_name]
        except KeyError:
            # Log the resolved name: the raw argument is None whenever the
            # default queue name is in use.
            error_logger.error(
                f"Please use initialize to initialize queue: {self.queue_name}"
            )
            raise

        self.endpoint_url = endpoint_url or getattr(
            settings, "INIESTA_SQS_ENDPOINT_URL", None
        )
        self._filters = None

        retry_count = retry_count or settings.INIESTA_LOCK_RETRY_COUNT
        lock_timeout = lock_timeout or settings.INIESTA_LOCK_TIMEOUT

        # TODO: get connection info from insanic get connection
        # Every cache whose name starts with "iniesta" becomes a lock backend.
        connections = [
            "redis://{HOST}:{PORT}/{DATABASE}".format(**conn_info)
            for cache_name, conn_info in settings.INSANIC_CACHES.items()
            if cache_name.startswith("iniesta")
        ]
        self.lock_manager = Aioredlock(
            connections,
            retry_count=retry_count,
            internal_lock_timeout=lock_timeout,
        )

    @classmethod
    def default_queue_name(cls) -> str:
        """
        Returns ``INIESTA_SQS_QUEUE_NAME`` when it is set (truthy), otherwise
        ``INIESTA_SQS_QUEUE_NAME_TEMPLATE`` formatted with the ``ENVIRONMENT``
        and ``SERVICE_NAME`` settings.
        """
        return (
            settings.INIESTA_SQS_QUEUE_NAME
            or settings.INIESTA_SQS_QUEUE_NAME_TEMPLATE.format(
                env=settings.ENVIRONMENT, service_name=settings.SERVICE_NAME
            )
        )

    @classmethod
    async def initialize(
        cls,
        *,
        queue_name: Optional[str] = None,
        endpoint_url: Optional[str] = None,
        region_name: Optional[str] = None,
    ):
        """
        The initialization classmethod that should be first run before any subsequent SQSClient initializations.

        Resolves the queue url through ``get_queue_url`` (once per queue
        name), caches it on the class and returns a client for that queue.

        :param queue_name: queue_name if want to initialize client with a different queue
        :param endpoint_url: SQS endpoint, defaults to ``INIESTA_SQS_ENDPOINT_URL`` if set
        :param region_name: AWS region, defaults to ``BotoSession.aws_default_region``
        :raises botocore.exceptions.ClientError: If the queue url could not be retrieved.
        :rtype: :code:`SQSClient`
        """
        session = BotoSession.get_session()
        endpoint_url = endpoint_url or getattr(
            settings, "INIESTA_SQS_ENDPOINT_URL", None
        )
        if queue_name is None:
            queue_name = cls.default_queue_name()

        # check if queue exists
        if queue_name not in cls.queue_urls:
            try:
                async with session.create_client(
                    "sqs",
                    region_name=region_name or BotoSession.aws_default_region,
                    endpoint_url=endpoint_url,
                    aws_access_key_id=BotoSession.aws_access_key_id,
                    aws_secret_access_key=BotoSession.aws_secret_access_key,
                ) as client:
                    response = await client.get_queue_url(QueueName=queue_name)
            except botocore.exceptions.ClientError as e:
                error_message = f"[{e.response['Error']['Code']}]: {e.response['Error']['Message']} {queue_name}"
                error_logger.critical(error_message)
                raise
            else:
                cls.queue_urls[queue_name] = response["QueueUrl"]

        # Forward the endpoint/region so the returned client talks to the
        # same SQS endpoint the queue url was resolved against.
        sqs_client = cls(
            queue_name=queue_name,
            endpoint_url=endpoint_url,
            region_name=region_name,
        )

        # check if subscription exists
        # await cls._confirm_subscription(sqs_client, topic_arn, endpoint_url)
        return sqs_client

    async def confirm_subscription(self, topic_arn: str) -> None:
        """
        Confirms the correct subscriptions are in place in AWS SNS

        :param topic_arn: Topic to check subscriptions for.
        :raises EnvironmentError: If no subscription for this queue is found.
        :raises AssertionError: If the registered filters on AWS do not match current config filters.
        """
        sns_client = SNSClient(topic_arn)

        # The subscription is ours when the queue name is one of the
        # ":"-separated parts of its "Endpoint" value.
        async for subs in sns_client.list_subscriptions_by_topic():
            if self.queue_name in subs.get("Endpoint", "").split(":"):
                service_subscriptions = subs
                break
        else:
            raise EnvironmentError(
                f"Unable to find subscription for {settings.SERVICE_NAME}"
            )

        if settings.INIESTA_ASSERT_FILTER_POLICIES:
            # check if filters match specified
            subscription_attributes = await sns_client.get_subscription_attributes(
                subscription_arn=service_subscriptions["SubscriptionArn"]
            )
            filter_policies = json.loads(
                subscription_attributes["Attributes"].get("FilterPolicy", "{}")
            )
            if filter_policies != self.filters:
                raise AssertionError(
                    f"Subscription filters and current filters are not equivalent. "
                    f"{filter_policies} {self.filters}"
                )

    async def confirm_permission(self) -> None:
        """
        Confirms correct permissions are in place.

        Only the first statement of the queue policy is inspected.

        :raises ImproperlyConfigured: If the permissions were not found.
        :raises AssertionError: If the permissions are not correctly configured on AWS.
        """
        session = BotoSession.get_session()
        async with session.create_client(
            "sqs",
            region_name=self.region_name or BotoSession.aws_default_region,
            endpoint_url=self.endpoint_url,
            aws_access_key_id=BotoSession.aws_access_key_id,
            aws_secret_access_key=BotoSession.aws_secret_access_key,
        ) as client:
            policy_attributes = await client.get_queue_attributes(
                QueueUrl=self.queue_url, AttributeNames=["Policy"]
            )

        try:
            policies = json.loads(policy_attributes["Attributes"]["Policy"])
            statement = policies["Statement"][0]
        except (KeyError, IndexError) as exc:
            # IndexError: the policy exists but has an empty statement list.
            raise ImproperlyConfigured("Permissions not found.") from exc

        # need "Effect": "Allow", "Action": "SQS:SendMessage"
        # Raised explicitly (not via `assert`) so the checks survive `-O`.
        if statement["Effect"] != "Allow":
            raise AssertionError(
                f"Queue policy effect is not Allow: {statement['Effect']}"
            )
        if "SQS:SendMessage" not in statement["Action"]:
            raise AssertionError(
                f"Queue policy does not permit SQS:SendMessage: {statement['Action']}"
            )
        # assert statement['Condition']['ArnEquals']['aws:SourceArn'] == topic_arn

    @property
    def filters(self) -> dict:
        """
        Filter policies built from ``INIESTA_SNS_EVENT_KEY`` and
        ``INIESTA_SQS_CONSUMER_FILTERS``; computed on first access and cached.
        """
        if self._filters is None:
            self._filters = filter_list_to_filter_policies(
                settings.INIESTA_SNS_EVENT_KEY,
                settings.INIESTA_SQS_CONSUMER_FILTERS,
            )
        return self._filters

    def start_receiving_messages(self, loop=None) -> None:
        """
        Method to start polling for messages.

        :param loop: Event loop to schedule the polling task on. Defaults to
            the loop returned by ``asyncio.get_event_loop()``.
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        self._receive_messages = True
        self._loop = loop
        # Schedule on the requested loop: _poll() checks self._loop, so the
        # task must live on that very loop.
        self._polling_task = asyncio.ensure_future(self._poll(), loop=loop)

    async def stop_receiving_messages(self) -> None:
        """
        Method to stop polling
        """
        self._receive_messages = False
        await self.lock_manager.destroy()
        # Nothing to cancel when polling was never started.
        if self._polling_task is not None:
            self._polling_task.cancel()

    async def handle_message(self, message: "SQSMessage") -> tuple:
        """
        Method that hold logic to handle a certain type of message

        The message is locked for the duration of the handler so that the
        same message id is not processed concurrently. Any exception raised
        gets ``message`` and ``handler`` attributes attached before it is
        re-raised (``handler`` is None if the failure happened before a
        handler was selected).

        :param message: Message to handle
        :raises LockError: If lock could not be acquired for the message
        :raises KeyError: If neither an event handler nor a default handler is registered
        :raises Exception: General exception handler attaches the message and message handler
        :return: Returns a tuple of the message and result of the handler
        """
        lock = None
        try:
            lock = await self.lock_manager.lock(
                self.lock_key.format(message_id=message.message_id)
            )
            if not lock.valid:
                raise LockError(
                    f"Could not acquire lock for {message.message_id}"
                )

            if message.event in self.handlers:
                handler = self.handlers[message.event]
            elif default in self.handlers:
                handler = self.handlers[default]
            else:
                raise KeyError(f"{message.event} handler not found!")
        except Exception as e:
            e.message = message
            e.handler = None
            raise
        else:
            try:
                result = handler(message)
                # handlers may be plain functions or coroutine functions
                if isawaitable(result):
                    result = await result
                return message, result
            except Exception as e:
                e.message = message
                e.handler = handler
                raise
        finally:
            if lock:
                await self.lock_manager.unlock(lock)

    def handle_error(self, exc: Exception) -> None:
        """
        If an exception occurred while handling the message, log the error.

        :param exc: The exception raised by :meth:`handle_message`. The
            message context is only logged when the exception carries a
            ``message`` attribute.
        """
        # An exception raised while unlocking (the `finally` of
        # handle_message) never gets a message attached, so do not assume
        # the attribute exists.
        message = getattr(exc, "message", None)
        handler = getattr(exc, "handler", None)

        extra = {}
        if message is not None:
            extra.update(
                {
                    "iniesta_pass": message.event,
                    "sqs_message_id": message.message_id,
                    "sqs_receipt_handle": message.receipt_handle,
                    "sqs_md5_of_body": message.md5_of_body,
                    "sqs_message_body": message.raw_body,
                    "sqs_attributes": json.dumps(message.attributes),
                }
            )
        extra["handler_name"] = handler.__qualname__ if handler else None

        error_logger.critical(
            f"[INIESTA] Error while handling message: {str(exc)}",
            exc_info=exc,
            extra=extra,
        )

    async def handle_success(self, client, message: "SQSMessage") -> dict:
        """
        Success handler for a message. Deletes the message from SQS.

        :param client: aws sqs client
        :param message: The message that was handled successfully.
        :return: Returns the response of the delete_message request.
        """
        message_id = message.message_id

        # if success must delete message from sqs
        logger.info(
            f"[INIESTA] Message handled successfully: msg_id={message_id}",
            extra={"sqs_message_id": message_id},
        )
        resp = await client.delete_message(
            QueueUrl=self.queue_url, ReceiptHandle=message.receipt_handle
        )
        logger.debug(
            f"[INIESTA] Message deleted: msg_id={message_id} "
            f"receipt_handle={message.receipt_handle}",
            extra={"sqs_message_id": message_id},
        )
        return resp

    async def _poll(self) -> str:
        """
        The long running method that consistently polls the SQS queue for
        messages.

        Keeps receiving while the loop is running and polling is enabled.
        Each received message is handled in its own task; successes are
        deleted from the queue, failures are only logged.

        :return: "Cancelled" if the task was cancelled, "Stopped" on
            ``StopPolling``, otherwise "Shutdown".
        """
        session = BotoSession.get_session()
        async with session.create_client(
            "sqs",
            region_name=self.region_name,
            endpoint_url=self.endpoint_url,
            aws_access_key_id=BotoSession.aws_access_key_id,
            aws_secret_access_key=BotoSession.aws_secret_access_key,
        ) as client:
            try:
                while self._loop.is_running() and self._receive_messages:
                    try:
                        response = await client.receive_message(
                            QueueUrl=self.queue_url,
                            MaxNumberOfMessages=settings.INIESTA_SQS_RECEIVE_MESSAGE_MAX_NUMBER_OF_MESSAGES,
                            WaitTimeSeconds=settings.INIESTA_SQS_RECEIVE_MESSAGE_WAIT_TIME_SECONDS,
                            AttributeNames=["All"],
                            MessageAttributeNames=["All"],
                        )
                    except botocore.exceptions.ClientError as e:
                        error_logger.critical(
                            f"[INIESTA] [{e.response['Error']['Code']}]: {e.response['Error']['Message']}"
                        )
                    else:
                        event_tasks = [
                            asyncio.ensure_future(
                                self.handle_message(
                                    SQSMessage.from_sqs(client, message)
                                )
                            )
                            for message in response.get("Messages", [])
                        ]

                        for fut in asyncio.as_completed(event_tasks):
                            # NOTE: must catch CancelledError and raise
                            try:
                                message_obj, result = await fut
                            except asyncio.CancelledError:
                                raise
                            except Exception as e:
                                # if error log failure and pass so sqs message persists and message becomes visible again
                                self.handle_error(e)
                            else:
                                await self.handle_success(client, message_obj)

                    # NOTE(review): nesting reconstructed from flattened
                    # source; assumed to run once per receive cycle - confirm.
                    await self.hook_post_receive_message_handler()
            except asyncio.CancelledError:
                logger.info("[INIESTA] POLLING TASK CANCELLED")
                return "Cancelled"
            except StopPolling:
                # mainly used for tests
                logger.info("[INIESTA] STOP POLLING")
                return "Stopped"
            except Exception:
                if self._receive_messages and self._loop.is_running():
                    error_logger.critical("[INIESTA] POLLING TASK RESTARTING")
                    # restart on the same loop this client was started on
                    self._polling_task = asyncio.ensure_future(
                        self._poll(), loop=self._loop
                    )
                error_logger.exception("[INIESTA] POLLING EXCEPTION CAUGHT")
            finally:
                await client.close()

        return "Shutdown"  # pragma: no cover

    @classmethod
    def handler(
        cls, event: Union[Callable, str, list, tuple] = None
    ) -> Callable:
        """
        Decorator for attaching a message handler for an event or if None, a default handler.

        Usable bare (``@SQSClient.handler``) to register the default
        handler, or called with an event or a list/tuple of events
        (``@SQSClient.handler("event")``).

        :param event: The decorated function when used bare, otherwise the
            event(s) to attach the decorated function to.
        :return: The decorated function, or a decorator that registers it.
        """
        if event and isfunction(event):
            # bare usage: `event` is actually the decorated function
            cls.add_handler(event, default)
            return event

        def register_handler(func):
            cls.add_handler(func, default if event is None else event)
            return func

        return register_handler

    @classmethod
    def add_handler(
        cls, handler: Callable, event: Union[str, list, tuple] = default
    ) -> None:
        """
        Method for manually declaring a handler for event(s).

        All events are validated before any of them is registered.

        :param handler: A function to execute
        :param event: The event(or a list of event) the function is attached to.
        :raises ValueError: If the handler takes no parameters, an event is
            duplicated in the list, or an event already has a handler.
        """
        cls._validate_handler_signature(handler)

        if isinstance(event, (list, tuple)):
            cls._validate_event_iterable(event)
            for e in event:
                cls._add_handler(handler, e)
        else:
            cls._validate_event_name(event)
            cls._add_handler(handler, event)

    @classmethod
    def _validate_event_iterable(cls, events):
        """Rejects duplicated events and events that already have a handler."""
        if len(set(events)) != len(events):
            raise ValueError("Duplication found in list of event")
        for e in events:
            cls._validate_event_name(e)

    @classmethod
    def _validate_event_name(cls, event):
        """Rejects an event that already has a handler registered."""
        if event in cls.handlers:
            raise ValueError(f"Handler for event [{event}] already exists.")

    @classmethod
    def _validate_handler_signature(cls, handler):
        """Rejects a handler that does not accept any parameter."""
        args = signature(handler).parameters
        if not args:
            raise ValueError(
                f"Required parameter `message` missing "
                f"in the {handler.__name__}() route?"
            )

    @classmethod
    def _add_handler(cls, handler, event):
        """Stores the handler for the event in the shared registry."""
        cls.handlers[event] = handler

    async def hook_post_receive_message_handler(self):  # pragma: no cover
        """Hook awaited by the polling loop; a no-op unless overridden."""
        pass

    def create_message(self, message: Any) -> "SQSMessage":
        """
        A helper method to create an SQSMessage

        :param message: The message body. A json encodable object.
        :return: An SQSMessage bound to this client.
        """
        if not isinstance(message, str):
            message = json.dumps(message)
        return SQSMessage(self, message)