mirror of
https://github.com/ZwareBear/awx.git
synced 2026-04-14 03:51:49 -05:00
replace celery task decorators with a kombu-based publisher
this commit implements the bulk of `awx-manage run_dispatcher`, a new command that binds to RabbitMQ via kombu and balances messages across a pool of workers that are similar to celeryd workers in spirit. Specifically, this includes: - a new decorator, `awx.main.dispatch.task`, which can be used to decorate functions or classes so that they can be designated as "Tasks" - support for fanout/broadcast tasks (at this point in time, only `conf.Setting` memcached flushes use this functionality) - support for job reaping - support for success/failure hooks for job runs (i.e., `handle_work_success` and `handle_work_error`) - support for auto scaling worker pool that scale processes up and down on demand - minimal support for RPC, such as status checks and pool recycle/reload
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
def get_local_queuename():
|
||||
return settings.CLUSTER_HOST_ID.encode('utf-8')
|
||||
|
||||
60
awx/main/dispatch/control.py
Normal file
60
awx/main/dispatch/control.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import logging
|
||||
import socket
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
from awx.main.dispatch import get_local_queuename
|
||||
from kombu import Connection, Queue, Exchange, Producer, Consumer
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
class Control(object):
|
||||
|
||||
services = ('dispatcher', 'callback_receiver')
|
||||
result = None
|
||||
|
||||
def __init__(self, service):
|
||||
if service not in self.services:
|
||||
raise RuntimeError('{} must be in {}'.format(service, self.services))
|
||||
self.service = service
|
||||
queuename = get_local_queuename()
|
||||
self.queue = Queue(queuename, Exchange(queuename), routing_key=queuename)
|
||||
|
||||
def publish(self, msg, conn, host, **kwargs):
|
||||
producer = Producer(
|
||||
exchange=self.queue.exchange,
|
||||
channel=conn,
|
||||
routing_key=get_local_queuename()
|
||||
)
|
||||
producer.publish(msg, expiration=5, **kwargs)
|
||||
|
||||
def status(self, *args, **kwargs):
|
||||
return self.control_with_reply('status', *args, **kwargs)
|
||||
|
||||
def running(self, *args, **kwargs):
|
||||
return self.control_with_reply('running', *args, **kwargs)
|
||||
|
||||
def control_with_reply(self, command, host=None, timeout=5):
|
||||
host = host or settings.CLUSTER_HOST_ID
|
||||
logger.warn('checking {} {} for {}'.format(self.service, command, host))
|
||||
reply_queue = Queue(name="amq.rabbitmq.reply-to")
|
||||
self.result = None
|
||||
with Connection(settings.BROKER_URL) as conn:
|
||||
with Consumer(conn, reply_queue, callbacks=[self.process_message], no_ack=True):
|
||||
self.publish({'control': command}, conn, host, reply_to='amq.rabbitmq.reply-to')
|
||||
try:
|
||||
conn.drain_events(timeout=timeout)
|
||||
except socket.timeout:
|
||||
logger.error('{} did not reply within {}s'.format(self.service, timeout))
|
||||
raise
|
||||
return self.result
|
||||
|
||||
def control(self, msg, host=None, **kwargs):
|
||||
host = host or settings.CLUSTER_HOST_ID
|
||||
with Connection(settings.BROKER_URL) as conn:
|
||||
self.publish(msg, conn, host)
|
||||
|
||||
def process_message(self, body, message):
|
||||
self.result = body
|
||||
message.ack()
|
||||
@@ -1,81 +1,260 @@
|
||||
import errno
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import random
|
||||
import traceback
|
||||
from uuid import uuid4
|
||||
|
||||
import collections
|
||||
from multiprocessing import Process
|
||||
from multiprocessing import Queue as MPQueue
|
||||
from Queue import Full as QueueFull
|
||||
from Queue import Full as QueueFull, Empty as QueueEmpty
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import connection as django_connection
|
||||
from django.core.cache import cache as django_cache
|
||||
from jinja2 import Template
|
||||
import psutil
|
||||
|
||||
from awx.main.models import UnifiedJob
|
||||
from awx.main.dispatch import reaper
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
def signame(sig):
|
||||
return dict(
|
||||
(k, v) for v, k in signal.__dict__.items()
|
||||
if v.startswith('SIG') and not v.startswith('SIG_')
|
||||
)[sig]
|
||||
class PoolWorker(object):
|
||||
'''
|
||||
Used to track a worker child process and its pending and finished messages.
|
||||
|
||||
This class makes use of two distinct multiprocessing.Queues to track state:
|
||||
|
||||
- self.queue: this is a queue which represents pending messages that should
|
||||
be handled by this worker process; as new AMQP messages come
|
||||
in, a pool will put() them into this queue; the child
|
||||
process that is forked will get() from this queue and handle
|
||||
received messages in an endless loop
|
||||
- self.finished: this is a queue which the worker process uses to signal
|
||||
that it has finished processing a message
|
||||
|
||||
When a message is put() onto this worker, it is tracked in
|
||||
self.managed_tasks.
|
||||
|
||||
Periodically, the worker will call .calculate_managed_tasks(), which will
|
||||
cause messages in self.finished to be removed from self.managed_tasks.
|
||||
|
||||
In this way, self.managed_tasks represents a view of the messages assigned
|
||||
to a specific process. The message at [0] is the least-recently inserted
|
||||
message, and it represents what the worker is running _right now_
|
||||
(self.current_task).
|
||||
|
||||
A worker is "busy" when it has at least one message in self.managed_tasks.
|
||||
It is "idle" when self.managed_tasks is empty.
|
||||
'''
|
||||
|
||||
def __init__(self, queue_size, target, args):
|
||||
self.messages_sent = 0
|
||||
self.messages_finished = 0
|
||||
self.managed_tasks = collections.OrderedDict()
|
||||
self.finished = MPQueue(queue_size)
|
||||
self.queue = MPQueue(queue_size)
|
||||
self.process = Process(target=target, args=(self.queue, self.finished) + args)
|
||||
self.process.daemon = True
|
||||
|
||||
def start(self):
|
||||
self.process.start()
|
||||
|
||||
def put(self, body):
|
||||
uuid = '?'
|
||||
if isinstance(body, dict):
|
||||
if not body.get('uuid'):
|
||||
body['uuid'] = str(uuid4())
|
||||
uuid = body['uuid']
|
||||
logger.debug('delivered {} to worker[{}] qsize {}'.format(
|
||||
uuid, self.pid, self.qsize
|
||||
))
|
||||
self.managed_tasks[uuid] = body
|
||||
self.queue.put(body, block=True, timeout=5)
|
||||
self.messages_sent += 1
|
||||
self.calculate_managed_tasks()
|
||||
|
||||
def quit(self):
|
||||
'''
|
||||
Send a special control message to the worker that tells it to exit
|
||||
gracefully.
|
||||
'''
|
||||
self.queue.put('QUIT')
|
||||
|
||||
@property
|
||||
def pid(self):
|
||||
return self.process.pid
|
||||
|
||||
@property
|
||||
def qsize(self):
|
||||
return self.queue.qsize()
|
||||
|
||||
@property
|
||||
def alive(self):
|
||||
return self.process.is_alive()
|
||||
|
||||
@property
|
||||
def mb(self):
|
||||
if self.alive:
|
||||
return '{:0.3f}'.format(
|
||||
psutil.Process(self.pid).memory_info().rss / 1024.0 / 1024.0
|
||||
)
|
||||
return '0'
|
||||
|
||||
@property
|
||||
def exitcode(self):
|
||||
return str(self.process.exitcode)
|
||||
|
||||
def calculate_managed_tasks(self):
|
||||
# look to see if any tasks were finished
|
||||
finished = []
|
||||
for _ in range(self.finished.qsize()):
|
||||
try:
|
||||
finished.append(self.finished.get(block=False))
|
||||
except QueueEmpty:
|
||||
break # qsize is not always _totally_ up to date
|
||||
|
||||
# if any tasks were finished, removed them from the managed tasks for
|
||||
# this worker
|
||||
for uuid in finished:
|
||||
self.messages_finished += 1
|
||||
del self.managed_tasks[uuid]
|
||||
|
||||
@property
|
||||
def current_task(self):
|
||||
self.calculate_managed_tasks()
|
||||
# the task at [0] is the one that's running right now (or is about to
|
||||
# be running)
|
||||
if len(self.managed_tasks):
|
||||
return self.managed_tasks[self.managed_tasks.keys()[0]]
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def orphaned_tasks(self):
|
||||
orphaned = []
|
||||
if not self.alive:
|
||||
# if this process had a running task that never finished,
|
||||
# requeue its error callbacks
|
||||
current_task = self.current_task
|
||||
if isinstance(current_task, dict):
|
||||
orphaned.extend(current_task.get('errbacks', []))
|
||||
|
||||
# if this process has any pending messages requeue them
|
||||
for _ in range(self.qsize):
|
||||
try:
|
||||
orphaned.append(self.queue.get(block=False))
|
||||
except QueueEmpty:
|
||||
break # qsize is not always _totally_ up to date
|
||||
if len(orphaned):
|
||||
logger.error(
|
||||
'requeuing {} messages from gone worker pid:{}'.format(
|
||||
len(orphaned), self.pid
|
||||
)
|
||||
)
|
||||
return orphaned
|
||||
|
||||
@property
|
||||
def busy(self):
|
||||
self.calculate_managed_tasks()
|
||||
return len(self.managed_tasks) > 0
|
||||
|
||||
@property
|
||||
def idle(self):
|
||||
return not self.busy
|
||||
|
||||
|
||||
class WorkerPool(object):
|
||||
'''
|
||||
Creates a pool of forked PoolWorkers.
|
||||
|
||||
As WorkerPool.write(...) is called (generally, by a kombu consumer
|
||||
implementation when it receives an AMQP message), messages are passed to
|
||||
one of the multiprocessing Queues where some work can be done on them.
|
||||
|
||||
class MessagePrinter(awx.main.dispatch.worker.BaseWorker):
|
||||
|
||||
def perform_work(self, body):
|
||||
print body
|
||||
|
||||
pool = WorkerPool(min_workers=4) # spawn four worker processes
|
||||
pool.init_workers(MessagePrint().work_loop)
|
||||
pool.write(
|
||||
0, # preferred worker 0
|
||||
'Hello, World!'
|
||||
)
|
||||
'''
|
||||
|
||||
debug_meta = ''
|
||||
|
||||
def __init__(self, min_workers=None, queue_size=None):
|
||||
self.name = settings.CLUSTER_HOST_ID
|
||||
self.pid = os.getpid()
|
||||
self.min_workers = min_workers or settings.JOB_EVENT_WORKERS
|
||||
self.queue_size = queue_size or settings.JOB_EVENT_MAX_QUEUE_SIZE
|
||||
|
||||
# self.workers tracks the state of worker running worker processes:
|
||||
# [
|
||||
# (total_messages_consumed, multiprocessing.Queue, multiprocessing.Process),
|
||||
# (total_messages_consumed, multiprocessing.Queue, multiprocessing.Process),
|
||||
# (total_messages_consumed, multiprocessing.Queue, multiprocessing.Process),
|
||||
# (total_messages_consumed, multiprocessing.Queue, multiprocessing.Process)
|
||||
# ]
|
||||
self.workers = []
|
||||
|
||||
def __len__(self):
|
||||
return len(self.workers)
|
||||
|
||||
def init_workers(self, target, *target_args):
|
||||
def shutdown_handler(active_workers):
|
||||
def _handler(signum, frame):
|
||||
logger.debug('received shutdown {}'.format(signame(signum)))
|
||||
try:
|
||||
for active_worker in active_workers:
|
||||
logger.debug('terminating worker')
|
||||
signal.signal(signum, signal.SIG_DFL)
|
||||
os.kill(os.getpid(), signum) # Rethrow signal, this time without catching it
|
||||
except Exception:
|
||||
logger.exception('error in shutdown_handler')
|
||||
return _handler
|
||||
self.target = target
|
||||
self.target_args = target_args
|
||||
for idx in range(self.min_workers):
|
||||
self.up()
|
||||
|
||||
def up(self):
|
||||
idx = len(self.workers)
|
||||
# It's important to close these because we're _about_ to fork, and we
|
||||
# don't want the forked processes to inherit the open sockets
|
||||
# for the DB and memcached connections (that way lies race conditions)
|
||||
django_connection.close()
|
||||
django_cache.close()
|
||||
for idx in range(self.min_workers):
|
||||
queue_actual = MPQueue(self.queue_size)
|
||||
w = Process(target=target, args=(queue_actual, idx,) + target_args)
|
||||
w.start()
|
||||
logger.debug('started {}[{}]'.format(target.im_self.__class__.__name__, idx))
|
||||
self.workers.append([0, queue_actual, w])
|
||||
worker = PoolWorker(self.queue_size, self.target, (idx,) + self.target_args)
|
||||
self.workers.append(worker)
|
||||
try:
|
||||
worker.start()
|
||||
except Exception:
|
||||
logger.exception('could not fork')
|
||||
else:
|
||||
logger.warn('scaling up worker pid:{}'.format(worker.pid))
|
||||
return idx, worker
|
||||
|
||||
signal.signal(signal.SIGINT, shutdown_handler([p[2] for p in self.workers]))
|
||||
signal.signal(signal.SIGTERM, shutdown_handler([p[2] for p in self.workers]))
|
||||
def debug(self, *args, **kwargs):
|
||||
self.cleanup()
|
||||
tmpl = Template(
|
||||
'{{ pool.name }}[pid:{{ pool.pid }}] workers total={{ workers|length }} {{ meta }} \n'
|
||||
'{% for w in workers %}'
|
||||
'. worker[pid:{{ w.pid }}]{% if not w.alive %} GONE exit={{ w.exitcode }}{% endif %}'
|
||||
' sent={{ w.messages_sent }}'
|
||||
' finished={{ w.messages_finished }}'
|
||||
' qsize={{ w.managed_tasks|length }}'
|
||||
' rss={{ w.mb }}MB'
|
||||
'{% for task in w.managed_tasks.values() %}'
|
||||
'\n - {% if loop.index0 == 0 %}running {% else %}queued {% endif %}'
|
||||
'{{ task["uuid"] }} '
|
||||
'{% if "task" in task %}'
|
||||
'{{ task["task"].rsplit(".", 1)[-1] }}'
|
||||
# don't print kwargs, they often contain launch-time secrets
|
||||
'(*{{ task.get("args", []) }})'
|
||||
'{% endif %}'
|
||||
'{% endfor %}'
|
||||
'{% if not w.managed_tasks|length %}'
|
||||
' [IDLE]'
|
||||
'{% endif %}'
|
||||
'\n'
|
||||
'{% endfor %}'
|
||||
)
|
||||
return tmpl.render(pool=self, workers=self.workers, meta=self.debug_meta)
|
||||
|
||||
def write(self, preferred_queue, body):
|
||||
queue_order = sorted(range(self.min_workers), cmp=lambda x, y: -1 if x==preferred_queue else 0)
|
||||
queue_order = sorted(range(len(self.workers)), cmp=lambda x, y: -1 if x==preferred_queue else 0)
|
||||
write_attempt_order = []
|
||||
for queue_actual in queue_order:
|
||||
try:
|
||||
worker_actual = self.workers[queue_actual]
|
||||
worker_actual[1].put(body, block=True, timeout=5)
|
||||
logger.debug('delivered to Worker[{}] qsize {}'.format(
|
||||
queue_actual, worker_actual[1].qsize()
|
||||
))
|
||||
worker_actual[0] += 1
|
||||
self.workers[queue_actual].put(body)
|
||||
return queue_actual
|
||||
except QueueFull:
|
||||
pass
|
||||
@@ -87,11 +266,113 @@ class WorkerPool(object):
|
||||
logger.warn("could not write payload to any queue, attempted order: {}".format(write_attempt_order))
|
||||
return None
|
||||
|
||||
def stop(self):
|
||||
for worker in self.workers:
|
||||
messages, queue, process = worker
|
||||
try:
|
||||
os.kill(process.pid, signal.SIGTERM)
|
||||
except OSError as e:
|
||||
if e.errno != errno.ESRCH:
|
||||
raise
|
||||
def stop(self, signum):
|
||||
try:
|
||||
for worker in self.workers:
|
||||
os.kill(worker.pid, signum)
|
||||
except Exception:
|
||||
logger.exception('could not kill {}'.format(worker.pid))
|
||||
|
||||
|
||||
class AutoscalePool(WorkerPool):
|
||||
'''
|
||||
An extended pool implementation that automatically scales workers up and
|
||||
down based on demand
|
||||
'''
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.max_workers = kwargs.pop('max_workers', None)
|
||||
super(AutoscalePool, self).__init__(*args, **kwargs)
|
||||
|
||||
if self.max_workers is None:
|
||||
settings_absmem = getattr(settings, 'SYSTEM_TASK_ABS_MEM', None)
|
||||
if settings_absmem is not None:
|
||||
total_memory_gb = int(settings_absmem)
|
||||
else:
|
||||
total_memory_gb = (psutil.virtual_memory().total >> 30) + 1 # noqa: round up
|
||||
# 5 workers per GB of total memory
|
||||
self.max_workers = (total_memory_gb * 5)
|
||||
|
||||
@property
|
||||
def should_grow(self):
|
||||
if len(self.workers) < self.min_workers:
|
||||
# If we don't have at least min_workers, add more
|
||||
return True
|
||||
# If every worker is busy doing something, add more
|
||||
return all([w.busy for w in self.workers])
|
||||
|
||||
@property
|
||||
def full(self):
|
||||
return len(self.workers) == self.max_workers
|
||||
|
||||
@property
|
||||
def debug_meta(self):
|
||||
return 'min={} max={}'.format(self.min_workers, self.max_workers)
|
||||
|
||||
def cleanup(self):
|
||||
"""
|
||||
Perform some internal account and cleanup. This is run on
|
||||
every cluster node heartbeat:
|
||||
|
||||
1. Discover worker processes that exited, and recover messages they
|
||||
were handling.
|
||||
2. Clean up unnecessary, idle workers.
|
||||
"""
|
||||
orphaned = []
|
||||
for w in self.workers[::]:
|
||||
if not w.alive:
|
||||
# the worker process has exited
|
||||
# 1. take the task it was running and enqueue the error
|
||||
# callbacks
|
||||
# 2. take any pending tasks delivered to its queue and
|
||||
# send them to another worker
|
||||
logger.error('worker pid:{} is gone (exit={})'.format(w.pid, w.exitcode))
|
||||
if w.current_task:
|
||||
try:
|
||||
for j in UnifiedJob.objects.filter(celery_task_id=w.current_task['uuid']):
|
||||
reaper.reap_job(j, 'failed')
|
||||
except Exception:
|
||||
logger.exception('failed to reap job UUID {}'.format(w.current_task['uuid']))
|
||||
orphaned.extend(w.orphaned_tasks)
|
||||
self.workers.remove(w)
|
||||
elif w.idle and len(self.workers) > self.min_workers:
|
||||
# the process has an empty queue (it's idle) and we have
|
||||
# more processes in the pool than we need (> min)
|
||||
# send this process a message so it will exit gracefully
|
||||
# at the next opportunity
|
||||
logger.warn('scaling down worker pid:{}'.format(w.pid))
|
||||
w.quit()
|
||||
self.workers.remove(w)
|
||||
|
||||
for m in orphaned:
|
||||
# if all the workers are dead, spawn at least one
|
||||
if not len(self.workers):
|
||||
self.up()
|
||||
idx = random.choice(range(len(self.workers)))
|
||||
self.write(idx, m)
|
||||
|
||||
def up(self):
|
||||
if self.full:
|
||||
# if we can't spawn more workers, just toss this message into a
|
||||
# random worker's backlog
|
||||
idx = random.choice(range(len(self.workers)))
|
||||
return idx, self.workers[idx]
|
||||
else:
|
||||
return super(AutoscalePool, self).up()
|
||||
|
||||
def write(self, preferred_queue, body):
|
||||
# when the cluster heartbeat occurs, clean up internally
|
||||
if isinstance(body, dict) and 'cluster_node_heartbeat' in body['task']:
|
||||
self.cleanup()
|
||||
if self.should_grow:
|
||||
self.up()
|
||||
# we don't care about "preferred queue" round robin distribution, just
|
||||
# find the first non-busy worker and claim it
|
||||
workers = self.workers[:]
|
||||
random.shuffle(workers)
|
||||
for w in workers:
|
||||
if not w.busy:
|
||||
w.put(body)
|
||||
break
|
||||
else:
|
||||
return super(AutoscalePool, self).write(preferred_queue, body)
|
||||
|
||||
128
awx/main/dispatch/publish.py
Normal file
128
awx/main/dispatch/publish.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
from uuid import uuid4
|
||||
|
||||
from django.conf import settings
|
||||
from kombu import Connection, Exchange, Producer
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
def serialize_task(f):
|
||||
return '.'.join([f.__module__, f.__name__])
|
||||
|
||||
|
||||
class task:
|
||||
"""
|
||||
Used to decorate a function or class so that it can be run asynchronously
|
||||
via the task dispatcher. Tasks can be simple functions:
|
||||
|
||||
@task()
|
||||
def add(a, b):
|
||||
return a + b
|
||||
|
||||
...or classes that define a `run` method:
|
||||
|
||||
@task()
|
||||
class Adder:
|
||||
def run(self, a, b):
|
||||
return a + b
|
||||
|
||||
# Tasks can be run synchronously...
|
||||
assert add(1, 1) == 2
|
||||
assert Adder().run(1, 1) == 2
|
||||
|
||||
# ...or published to a queue:
|
||||
add.apply_async([1, 1])
|
||||
Adder.apply_async([1, 1])
|
||||
|
||||
# Tasks can also define a specific target queue or exchange type:
|
||||
|
||||
@task(queue='slow-tasks')
|
||||
def snooze():
|
||||
time.sleep(10)
|
||||
|
||||
@task(queue='tower_broadcast', exchange_type='fanout')
|
||||
def announce():
|
||||
print "Run this everywhere!"
|
||||
"""
|
||||
|
||||
def __init__(self, queue=None, exchange_type=None):
|
||||
self.queue = queue
|
||||
self.exchange_type = exchange_type
|
||||
|
||||
def __call__(self, fn=None):
|
||||
queue = self.queue
|
||||
exchange_type = self.exchange_type
|
||||
|
||||
class PublisherMixin(object):
|
||||
|
||||
queue = None
|
||||
|
||||
@classmethod
|
||||
def delay(cls, *args, **kwargs):
|
||||
return cls.apply_async(args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def apply_async(cls, args=None, kwargs=None, queue=None, uuid=None, **kw):
|
||||
task_id = uuid or str(uuid4())
|
||||
args = args or []
|
||||
kwargs = kwargs or {}
|
||||
queue = (
|
||||
queue or
|
||||
getattr(cls.queue, 'im_func', cls.queue) or
|
||||
settings.CELERY_DEFAULT_QUEUE
|
||||
)
|
||||
obj = {
|
||||
'uuid': task_id,
|
||||
'args': args,
|
||||
'kwargs': kwargs,
|
||||
'task': cls.name
|
||||
}
|
||||
obj.update(**kw)
|
||||
if callable(queue):
|
||||
queue = queue()
|
||||
if not settings.IS_TESTING(sys.argv):
|
||||
with Connection(settings.BROKER_URL) as conn:
|
||||
exchange = Exchange(queue, type=exchange_type or 'direct')
|
||||
producer = Producer(conn)
|
||||
logger.debug('publish {}({}, queue={})'.format(
|
||||
cls.name,
|
||||
task_id,
|
||||
queue
|
||||
))
|
||||
producer.publish(obj,
|
||||
serializer='json',
|
||||
compression='bzip2',
|
||||
exchange=exchange,
|
||||
declare=[exchange],
|
||||
delivery_mode="persistent",
|
||||
routing_key=queue)
|
||||
return (obj, queue)
|
||||
|
||||
# If the object we're wrapping *is* a class (e.g., RunJob), return
|
||||
# a *new* class that inherits from the wrapped class *and* BaseTask
|
||||
# In this way, the new class returned by our decorator is the class
|
||||
# being decorated *plus* PublisherMixin so cls.apply_async() and
|
||||
# cls.delay() work
|
||||
bases = []
|
||||
ns = {'name': serialize_task(fn), 'queue': queue}
|
||||
if inspect.isclass(fn):
|
||||
bases = list(fn.__bases__)
|
||||
ns.update(fn.__dict__)
|
||||
cls = type(
|
||||
fn.__name__,
|
||||
tuple(bases + [PublisherMixin]),
|
||||
ns
|
||||
)
|
||||
if inspect.isclass(fn):
|
||||
return cls
|
||||
|
||||
# if the object being decorated is *not* a class (it's a Python
|
||||
# function), make fn.apply_async and fn.delay proxy through to the
|
||||
# PublisherMixin we dynamically created above
|
||||
setattr(fn, 'name', cls.name)
|
||||
setattr(fn, 'apply_async', cls.apply_async)
|
||||
setattr(fn, 'delay', cls.delay)
|
||||
return fn
|
||||
46
awx/main/dispatch/reaper.py
Normal file
46
awx/main/dispatch/reaper.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from datetime import timedelta
|
||||
import logging
|
||||
|
||||
from django.db.models import Q
|
||||
from django.utils.timezone import now as tz_now
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
|
||||
from awx.main.models import Instance, UnifiedJob, WorkflowJob
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
def reap_job(j, status):
|
||||
j.status = status
|
||||
j.start_args = '' # blank field to remove encrypted passwords
|
||||
j.job_explanation += ' '.join((
|
||||
'Task was marked as running in Tower but was not present in',
|
||||
'the job queue, so it has been marked as failed.',
|
||||
))
|
||||
j.save(update_fields=['status', 'start_args', 'job_explanation'])
|
||||
if hasattr(j, 'send_notification_templates'):
|
||||
j.send_notification_templates('failed')
|
||||
j.websocket_emit_status(status)
|
||||
logger.error(
|
||||
'{} is no longer running; reaping'.format(j.log_format)
|
||||
)
|
||||
|
||||
|
||||
def reap(instance=None, status='failed'):
|
||||
'''
|
||||
Reap all jobs in waiting|running for this instance.
|
||||
'''
|
||||
me = instance or Instance.objects.me()
|
||||
now = tz_now()
|
||||
workflow_ctype_id = ContentType.objects.get_for_model(WorkflowJob).id
|
||||
jobs = UnifiedJob.objects.filter(
|
||||
(
|
||||
Q(status='running') |
|
||||
Q(status='waiting', modified__lte=now - timedelta(seconds=60))
|
||||
) & (
|
||||
Q(execution_node=me.hostname) |
|
||||
Q(controller_node=me.hostname)
|
||||
) & ~Q(polymorphic_ctype_id=workflow_ctype_id)
|
||||
)
|
||||
for j in jobs:
|
||||
reap_job(j, status)
|
||||
3
awx/main/dispatch/worker/__init__.py
Normal file
3
awx/main/dispatch/worker/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .base import AWXConsumer, BaseWorker # noqa
|
||||
from .callback import CallbackBrokerWorker # noqa
|
||||
from .task import TaskWorker # noqa
|
||||
146
awx/main/dispatch/worker/base.py
Normal file
146
awx/main/dispatch/worker/base.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright (c) 2018 Ansible by Red Hat
|
||||
# All Rights Reserved.
|
||||
|
||||
import os
|
||||
import logging
|
||||
import signal
|
||||
from uuid import UUID
|
||||
from Queue import Empty as QueueEmpty
|
||||
|
||||
from kombu import Producer
|
||||
from kombu.mixins import ConsumerMixin
|
||||
|
||||
from awx.main.dispatch.pool import WorkerPool
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
def signame(sig):
|
||||
return dict(
|
||||
(k, v) for v, k in signal.__dict__.items()
|
||||
if v.startswith('SIG') and not v.startswith('SIG_')
|
||||
)[sig]
|
||||
|
||||
|
||||
class WorkerSignalHandler:
|
||||
|
||||
def __init__(self):
|
||||
self.kill_now = False
|
||||
signal.signal(signal.SIGINT, self.exit_gracefully)
|
||||
|
||||
def exit_gracefully(self, *args, **kwargs):
|
||||
self.kill_now = True
|
||||
|
||||
|
||||
class AWXConsumer(ConsumerMixin):
|
||||
|
||||
def __init__(self, name, connection, worker, queues=[], pool=None):
|
||||
self.connection = connection
|
||||
self.total_messages = 0
|
||||
self.queues = queues
|
||||
self.worker = worker
|
||||
self.pool = pool
|
||||
if pool is None:
|
||||
self.pool = WorkerPool()
|
||||
self.pool.init_workers(self.worker.work_loop)
|
||||
|
||||
def get_consumers(self, Consumer, channel):
|
||||
logger.debug(self.listening_on)
|
||||
return [Consumer(queues=self.queues, accept=['json'],
|
||||
callbacks=[self.process_task])]
|
||||
|
||||
@property
|
||||
def listening_on(self):
|
||||
return 'listening on {}'.format([
|
||||
'{} [{}]'.format(q.name, q.exchange.type) for q in self.queues
|
||||
])
|
||||
|
||||
def control(self, body, message):
|
||||
logger.warn(body)
|
||||
control = body.get('control')
|
||||
if control in ('status', 'running'):
|
||||
producer = Producer(
|
||||
channel=self.connection,
|
||||
routing_key=message.properties['reply_to']
|
||||
)
|
||||
if control == 'status':
|
||||
msg = '\n'.join([self.listening_on, self.pool.debug()])
|
||||
elif control == 'running':
|
||||
msg = []
|
||||
for worker in self.pool.workers:
|
||||
worker.calculate_managed_tasks()
|
||||
msg.extend(worker.managed_tasks.keys())
|
||||
producer.publish(msg)
|
||||
elif control == 'reload':
|
||||
for worker in self.pool.workers:
|
||||
worker.quit()
|
||||
else:
|
||||
logger.error('unrecognized control message: {}'.format(control))
|
||||
message.ack()
|
||||
|
||||
def process_task(self, body, message):
|
||||
if 'control' in body:
|
||||
return self.control(body, message)
|
||||
if len(self.pool):
|
||||
if "uuid" in body and body['uuid']:
|
||||
try:
|
||||
queue = UUID(body['uuid']).int % len(self.pool)
|
||||
except Exception:
|
||||
queue = self.total_messages % len(self.pool)
|
||||
else:
|
||||
queue = self.total_messages % len(self.pool)
|
||||
else:
|
||||
queue = 0
|
||||
self.pool.write(queue, body)
|
||||
self.total_messages += 1
|
||||
message.ack()
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
signal.signal(signal.SIGINT, self.stop)
|
||||
signal.signal(signal.SIGTERM, self.stop)
|
||||
self.worker.on_start()
|
||||
super(AWXConsumer, self).run(*args, **kwargs)
|
||||
|
||||
def stop(self, signum, frame):
|
||||
self.should_stop = True # this makes the kombu mixin stop consuming
|
||||
logger.debug('received {}, stopping'.format(signame(signum)))
|
||||
self.worker.on_stop()
|
||||
raise SystemExit()
|
||||
|
||||
|
||||
class BaseWorker(object):
|
||||
|
||||
def work_loop(self, queue, finished, idx, *args):
|
||||
ppid = os.getppid()
|
||||
signal_handler = WorkerSignalHandler()
|
||||
while not signal_handler.kill_now:
|
||||
# if the parent PID changes, this process has been orphaned
|
||||
# via e.g., segfault or sigkill, we should exit too
|
||||
if os.getppid() != ppid:
|
||||
break
|
||||
try:
|
||||
body = queue.get(block=True, timeout=1)
|
||||
if body == 'QUIT':
|
||||
break
|
||||
except QueueEmpty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error("Exception on worker, restarting: " + str(e))
|
||||
continue
|
||||
try:
|
||||
self.perform_work(body, *args)
|
||||
finally:
|
||||
if 'uuid' in body:
|
||||
uuid = body['uuid']
|
||||
logger.debug('task {} is finished'.format(uuid))
|
||||
finished.put(uuid)
|
||||
logger.warn('worker exiting gracefully pid:{}'.format(os.getpid()))
|
||||
|
||||
def perform_work(self, body):
|
||||
raise NotImplementedError()
|
||||
|
||||
def on_start(self):
|
||||
pass
|
||||
|
||||
def on_stop(self):
|
||||
pass
|
||||
@@ -1,83 +1,30 @@
|
||||
# Copyright (c) 2018 Ansible by Red Hat
|
||||
# All Rights Reserved.
|
||||
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
import traceback
|
||||
from uuid import UUID
|
||||
from Queue import Empty as QueueEmpty
|
||||
|
||||
from kombu.mixins import ConsumerMixin
|
||||
from django.conf import settings
|
||||
from django.db import DatabaseError, OperationalError, connection as django_connection
|
||||
from django.db.utils import InterfaceError, InternalError
|
||||
|
||||
from awx.main.consumers import emit_channel_notification
|
||||
from awx.main.models import (JobEvent, AdHocCommandEvent, ProjectUpdateEvent,
|
||||
InventoryUpdateEvent, SystemJobEvent, UnifiedJob)
|
||||
from awx.main.consumers import emit_channel_notification
|
||||
from awx.main.dispatch.pool import WorkerPool
|
||||
|
||||
from .base import BaseWorker
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
class WorkerSignalHandler:
|
||||
|
||||
def __init__(self):
|
||||
self.kill_now = False
|
||||
signal.signal(signal.SIGINT, self.exit_gracefully)
|
||||
signal.signal(signal.SIGTERM, self.exit_gracefully)
|
||||
|
||||
def exit_gracefully(self, *args, **kwargs):
|
||||
self.kill_now = True
|
||||
|
||||
|
||||
class AWXConsumer(ConsumerMixin):
|
||||
|
||||
def __init__(self, connection, worker, queues=[]):
|
||||
self.connection = connection
|
||||
self.total_messages = 0
|
||||
self.queues = queues
|
||||
self.pool = WorkerPool()
|
||||
self.pool.init_workers(worker.work_loop)
|
||||
|
||||
def get_consumers(self, Consumer, channel):
|
||||
return [Consumer(queues=self.queues, accept=['json'],
|
||||
callbacks=[self.process_task])]
|
||||
|
||||
def process_task(self, body, message):
|
||||
if "uuid" in body and body['uuid']:
|
||||
try:
|
||||
queue = UUID(body['uuid']).int % len(self.pool)
|
||||
except Exception:
|
||||
queue = self.total_messages % len(self.pool)
|
||||
else:
|
||||
queue = self.total_messages % len(self.pool)
|
||||
self.pool.write(queue, body)
|
||||
self.total_messages += 1
|
||||
message.ack()
|
||||
|
||||
|
||||
class BaseWorker(object):
|
||||
|
||||
def work_loop(self, queue, idx, *args):
|
||||
signal_handler = WorkerSignalHandler()
|
||||
while not signal_handler.kill_now:
|
||||
try:
|
||||
body = queue.get(block=True, timeout=1)
|
||||
except QueueEmpty:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error("Exception on worker, restarting: " + str(e))
|
||||
continue
|
||||
self.perform_work(body, *args)
|
||||
|
||||
def perform_work(self, body):
|
||||
raise NotImplemented()
|
||||
|
||||
|
||||
class CallbackBrokerWorker(BaseWorker):
|
||||
'''
|
||||
A worker implementation that deserializes callback event data and persists
|
||||
it into the database.
|
||||
|
||||
The code that *builds* these types of messages is found in the AWX display
|
||||
callback (`awx.lib.awx_display_callback`).
|
||||
'''
|
||||
|
||||
MAX_RETRIES = 2
|
||||
|
||||
@@ -151,7 +98,7 @@ class CallbackBrokerWorker(BaseWorker):
|
||||
try:
|
||||
_save_event_data()
|
||||
break
|
||||
except (OperationalError, InterfaceError, InternalError) as e:
|
||||
except (OperationalError, InterfaceError, InternalError):
|
||||
if retries >= self.MAX_RETRIES:
|
||||
logger.exception('Worker could not re-establish database connectivity, shutting down gracefully: Job {}'.format(job_identifier))
|
||||
os.kill(os.getppid(), signal.SIGINT)
|
||||
@@ -164,7 +111,7 @@ class CallbackBrokerWorker(BaseWorker):
|
||||
django_connection.close()
|
||||
time.sleep(delay)
|
||||
retries += 1
|
||||
except DatabaseError as e:
|
||||
except DatabaseError:
|
||||
logger.exception('Database Error Saving Job Event for Job {}'.format(job_identifier))
|
||||
break
|
||||
except Exception as exc:
|
||||
113
awx/main/dispatch/worker/task.py
Normal file
113
awx/main/dispatch/worker/task.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import inspect
|
||||
import logging
|
||||
import importlib
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import six
|
||||
|
||||
from awx.main.tasks import dispatch_startup, inform_cluster_of_shutdown
|
||||
|
||||
from .base import BaseWorker
|
||||
|
||||
logger = logging.getLogger('awx.main.dispatch')
|
||||
|
||||
|
||||
class TaskWorker(BaseWorker):
|
||||
'''
|
||||
A worker implementation that deserializes task messages and runs native
|
||||
Python code.
|
||||
|
||||
The code that *builds* these types of messages is found in
|
||||
`awx.main.dispatch.publish`.
|
||||
'''
|
||||
|
||||
@classmethod
|
||||
def resolve_callable(cls, task):
|
||||
'''
|
||||
Transform a dotted notation task into an imported, callable function, e.g.,
|
||||
|
||||
awx.main.tasks.delete_inventory
|
||||
awx.main.tasks.RunProjectUpdate
|
||||
'''
|
||||
module, target = task.rsplit('.', 1)
|
||||
module = importlib.import_module(module)
|
||||
_call = None
|
||||
if hasattr(module, target):
|
||||
_call = getattr(module, target, None)
|
||||
return _call
|
||||
|
||||
def run_callable(self, body):
|
||||
'''
|
||||
Given some AMQP message, import the correct Python code and run it.
|
||||
'''
|
||||
task = body['task']
|
||||
uuid = body.get('uuid', '<unknown>')
|
||||
args = body.get('args', [])
|
||||
kwargs = body.get('kwargs', {})
|
||||
_call = TaskWorker.resolve_callable(task)
|
||||
if inspect.isclass(_call):
|
||||
# the callable is a class, e.g., RunJob; instantiate and
|
||||
# return its `run()` method
|
||||
_call = _call().run
|
||||
# don't print kwargs, they often contain launch-time secrets
|
||||
logger.debug('task {} starting {}(*{})'.format(uuid, task, args))
|
||||
return _call(*args, **kwargs)
|
||||
|
||||
def perform_work(self, body):
|
||||
'''
|
||||
Import and run code for a task e.g.,
|
||||
|
||||
body = {
|
||||
'args': [8],
|
||||
'callbacks': [{
|
||||
'args': [],
|
||||
'kwargs': {}
|
||||
'task': u'awx.main.tasks.handle_work_success'
|
||||
}],
|
||||
'errbacks': [{
|
||||
'args': [],
|
||||
'kwargs': {},
|
||||
'task': 'awx.main.tasks.handle_work_error'
|
||||
}],
|
||||
'kwargs': {},
|
||||
'task': u'awx.main.tasks.RunProjectUpdate'
|
||||
}
|
||||
'''
|
||||
result = None
|
||||
try:
|
||||
result = self.run_callable(body)
|
||||
except Exception as exc:
|
||||
|
||||
try:
|
||||
if getattr(exc, 'is_awx_task_error', False):
|
||||
# Error caused by user / tracked in job output
|
||||
logger.warning(six.text_type("{}").format(exc))
|
||||
else:
|
||||
task = body['task']
|
||||
args = body.get('args', [])
|
||||
kwargs = body.get('kwargs', {})
|
||||
logger.exception('Worker failed to run task {}(*{}, **{}'.format(
|
||||
task, args, kwargs
|
||||
))
|
||||
except Exception:
|
||||
# It's fairly critical that this code _not_ raise exceptions on logging
|
||||
# If you configure external logging in a way that _it_ fails, there's
|
||||
# not a lot we can do here; sys.stderr.write is a final hail mary
|
||||
_, _, tb = sys.exc_info()
|
||||
traceback.print_tb(tb)
|
||||
|
||||
for callback in body.get('errbacks', []) or []:
|
||||
callback['uuid'] = body['uuid']
|
||||
self.perform_work(callback)
|
||||
|
||||
for callback in body.get('callbacks', []) or []:
|
||||
callback['uuid'] = body['uuid']
|
||||
self.perform_work(callback)
|
||||
return result
|
||||
|
||||
def on_start(self):
|
||||
dispatch_startup()
|
||||
|
||||
def on_stop(self):
|
||||
inform_cluster_of_shutdown()
|
||||
Reference in New Issue
Block a user