from catcher.steps.external_step import ExternalStep
from catcher.steps.step import Step, update_variables
from catcher.utils.logger import warning
from catcher.utils.misc import try_get_object, fill_template_str, try_get_objects, fill_template
import ssl
from catcher_modules.mq import MqStepMixin
class Rabbit(ExternalStep, MqStepMixin):
    """
    Allows you to consume/produce messages from/to `RabbitMQ <https://www.rabbitmq.com/>`_

    :Input:

    :config: rabbitmq config object, used in other rabbitmq commands.

    - server: is the rabbit host, <rabbit-host:rabbit-port>
    - username: is the username
    - password: is the password
    - virtualhost: virtualhost *Optional* defaults to "/"
    - sslOptions: {'ssl_version': 'PROTOCOL_TLSv1, PROTOCOL_TLSv1_1 or PROTOCOL_TLSv1_2', 'ca_certs': '/path/to/ca_cert', 'keyfile': '/path/to/key', 'certfile': '/path/to/cert', 'cert_reqs': 'CERT_NONE, CERT_OPTIONAL or CERT_REQUIRED'}
                  Optional object to be used only when ssl is required.
                  If an empty object is passed ssl_version defaults to PROTOCOL_TLSv1_2 and cert_reqs defaults to CERT_NONE
    - disconnect_timeout: number of seconds to wait for a disconnect before force closing the connection.
                          *Optional* defaults to 10. Warning! Publish may fail if you use a too small timeout value.

    :consume: Consume message from rabbit.

    - config: rabbitmq config object
    - queue: the name of the queue to consume from

    :publish: Publish message to rabbit exchange.

    - config: rabbitmq config object
    - exchange: exchange to publish message
    - routing_key: routing key
    - headers: headers json *Optional*
    - data: data to be produced
    - data_from_file: data to be published. File can be used as data source. *Optional* Either `data` or `data_from_file` should be present.

    :Examples:

    Read message
    ::

        variables:
            rabbitmq_config:
                server: 127.0.0.1:5672
                username: 'guest'
                password: 'guest'
        steps:
            - rabbit:
                consume:
                    config: '{{ rabbitmq_config }}'
                    queue: 'test.catcher.queue'

    Publish `data` variable as message
    ::

        variables:
            rabbitmq_config:
                server: 127.0.0.1:5672
                sslOptions: {'ssl_version': 'PROTOCOL_TLSv1, PROTOCOL_TLSv1_1 or PROTOCOL_TLSv1_2', 'ca_certs': '/path/to/ca_cert', 'keyfile': '/path/to/key', 'certfile': '/path/to/cert', 'cert_reqs': 'CERT_NONE, CERT_OPTIONAL or CERT_REQUIRED'}
                username: 'guest'
                password: 'guest'
        steps:
            - rabbit:
                publish:
                    config: '{{ rabbitmq_config }}'
                    exchange: 'test.catcher.exchange'
                    routing_key: 'catcher.routing.key'
                    headers: {'test.header.1': 'header1', 'test.header.2': 'header1'}
                    data: '{{ data|tojson }}'

    Publish `data_from_file` variable as json message
    ::

        variables:
            rabbitmq_config:
                server: 127.0.0.1:5672
                username: 'guest'
                password: 'guest'
        steps:
            - rabbit:
                publish:
                    config: '{{ rabbitmq_config }}'
                    exchange: 'test.catcher.exchange'
                    routing_key: 'catcher.routing.key'
                    data_from_file: '{{ /path/to/file }}'

    """

    def __init__(self, **kwargs) -> None:
        """
        Read the step definition.

        The single non-predefined key of ``kwargs`` names the method
        (``publish`` or ``consume``); its value is the method's settings.

        :raises KeyError: if a mandatory setting (``config``, ``queue``,
                          ``exchange``, ``routing_key``, or ``data_from_file``
                          when ``data`` is absent) is missing.
        """
        super().__init__(**kwargs)
        method = Step.filter_predefined_keys(kwargs)  # publish/consume
        self.method = method.lower()
        conf = kwargs[method]
        self.config = conf['config']
        self.headers = conf.get('headers', {})
        self.message = None
        if self.method != 'consume':
            self.exchange = conf['exchange']
            self.routing_key = conf['routing_key']
            self.message = conf.get('data', None)
            self.file = None
            if self.message is None:
                # No inline payload: the body must come from a file.
                self.file = conf['data_from_file']
        else:
            self.queue = conf['queue']

    @update_variables
    def action(self, includes: dict, variables: dict) -> tuple:
        """
        Run the configured method against RabbitMQ.

        :param includes: included steps (not used by this step).
        :param variables: variables used to render the templated settings.
        :return: ``(variables, result)`` where ``result`` is the consumed
                 message for ``consume`` and ``None`` for ``publish``.
        :raises AttributeError: if the method is neither publish nor consume.
        """
        config = try_get_objects(fill_template_str(self.config, variables))
        # With no virtual host the URL built in _get_connection_parameters ends
        # with a bare "/" path; per the class docs this selects the default
        # "/" virtual host.
        if config.get('virtualhost') is None:
            config['virtualhost'] = ''
        disconnect_timeout = int(config.get('disconnect_timeout', 10))  # 10 sec for connection closed exception
        connection_parameters = self._get_connection_parameters(config)
        if self.method == 'publish':
            message = self.form_body(self.message, self.file, variables)
            return variables, self.publish(connection_parameters,
                                           fill_template_str(self.exchange, variables),
                                           fill_template_str(self.routing_key, variables),
                                           fill_template(self.headers, variables),
                                           message,
                                           disconnect_timeout)
        elif self.method == 'consume':
            return variables, self.consume(connection_parameters,
                                           fill_template_str(self.queue, variables),
                                           disconnect_timeout)
        else:
            raise AttributeError('unknown method: ' + self.method)

    @staticmethod
    def publish(connection_parameters, exchange, routing_key, headers, message, disconnect_timeout):
        """
        Publish ``message`` to ``exchange`` with ``routing_key`` and ``headers``.

        ``disconnect_timeout`` is stored on the connection parameters as
        ``blocked_connection_timeout`` before connecting. A pika
        ``ConnectionClosed`` raised anywhere inside the connect/publish/close
        sequence is logged as a warning instead of being propagated.

        :return: None
        """
        import pika
        from pika import exceptions

        properties = pika.BasicProperties(headers=headers)
        try:
            connection_parameters.blocked_connection_timeout = disconnect_timeout
            with pika.BlockingConnection(connection_parameters) as connection:
                channel = connection.channel()
                channel.basic_publish(exchange=exchange,
                                      routing_key=routing_key,
                                      properties=properties,
                                      body=message)
        except exceptions.ConnectionClosed:
            warning('Failed to gracefully close rabbit connection.')

    @staticmethod
    def consume(connection_parameters, queue, disconnect_timeout):
        """
        Fetch a single message from ``queue`` using ``basic_get``.

        ``disconnect_timeout`` is stored on the connection parameters as
        ``blocked_connection_timeout`` before connecting. Byte bodies are
        decoded as UTF-8. A received message is acknowledged and passed
        through ``try_get_object``.

        :return: the parsed message, or ``None`` if nothing was received.
        """
        import pika

        message = None
        connection_parameters.blocked_connection_timeout = disconnect_timeout
        with pika.BlockingConnection(connection_parameters) as connection:
            channel = connection.channel()
            method_frame, header_frame, body = channel.basic_get(queue)
            if isinstance(body, (bytes, bytearray)):
                body = body.decode('utf-8')
            if method_frame:
                channel.basic_ack(method_frame.delivery_tag)
                message = try_get_object(body)
        return message

    def _get_connection_parameters(self, config):
        """
        Build pika ``URLParameters`` from the rendered rabbitmq config.

        The ``amqps`` scheme and custom SSL options are used whenever
        ``sslOptions`` is present in the config, including an empty object,
        which the class docs define as "SSL with default settings".

        :raises KeyError: if ``username``, ``password``, ``server`` or
                          ``virtualhost`` is missing from ``config``.
        """
        import pika

        ssl_options = config.get('sslOptions')
        # Same "is not None" test for the scheme and for the options below, so
        # an empty sslOptions object enables TLS consistently in both places.
        use_ssl = ssl_options is not None
        amqp_url = 'amqp{}://{}:{}@{}/{}'.format('s' if use_ssl else '',
                                                 config['username'],
                                                 config['password'],
                                                 config['server'],
                                                 config['virtualhost'])
        parameters = pika.URLParameters(amqp_url)
        if use_ssl:
            parameters.ssl_options = self._get_ssl_options(ssl_options)
        return parameters

    @staticmethod
    def _get_ssl_options(ssl_options):
        """
        Translate the ``sslOptions`` config object into ``pika.SSLOptions``.

        Recognised keys:

        - ssl_version: PROTOCOL_TLSv1, PROTOCOL_TLSv1_1 or PROTOCOL_TLSv1_2.
          Missing or unrecognised values fall back to PROTOCOL_TLSv1_2.
        - cert_reqs: CERT_NONE, CERT_OPTIONAL or CERT_REQUIRED.
          Missing or unrecognised values fall back to CERT_NONE.
        - ca_certs: path to the CA certificates file used for verification.
        - certfile: path to the client certificate.
        - keyfile: path to the client private key. Only used together with
          ``certfile``; when omitted the key is read from ``certfile``.
        """
        import pika

        # PROTOCOL_TLSv1, PROTOCOL_TLSv1_1 or PROTOCOL_TLSv1_2
        ssl_versions = {
            'PROTOCOL_TLSv1': ssl.PROTOCOL_TLSv1,
            'PROTOCOL_TLSv1_1': ssl.PROTOCOL_TLSv1_1,
            'PROTOCOL_TLSv1_2': ssl.PROTOCOL_TLSv1_2,
        }
        # CERT_NONE, CERT_OPTIONAL or CERT_REQUIRED
        cert_reqs = {
            'CERT_NONE': ssl.CERT_NONE,
            'CERT_OPTIONAL': ssl.CERT_OPTIONAL,
            'CERT_REQUIRED': ssl.CERT_REQUIRED,
        }

        context = ssl.SSLContext(ssl_versions.get(ssl_options.get('ssl_version'), ssl.PROTOCOL_TLSv1_2))
        # The fallback must be the ssl constant: verify_mode rejects a str.
        context.verify_mode = cert_reqs.get(ssl_options.get('cert_reqs'), ssl.CERT_NONE)

        ca_certs = ssl_options.get('ca_certs')
        if ca_certs is not None:
            context.load_verify_locations(cafile=ca_certs)

        certfile = ssl_options.get('certfile')
        if certfile is not None:
            # The private key belongs to the certificate chain. It must never
            # be assigned to context.keylog_filename, which is the TLS key-log
            # (debug) file that Python opens for appending.
            context.load_cert_chain(certfile, keyfile=ssl_options.get('keyfile'))
        return pika.SSLOptions(context=context)