# Copyright (c) 2014 Ansible, Inc.
# All Rights Reserved

# Python
import os
import datetime
import logging
import json
import signal
import time
from optparse import make_option
from multiprocessing import Process

# Django
from django.conf import settings
from django.core.management.base import NoArgsCommand, CommandError
from django.db import transaction, DatabaseError
from django.contrib.auth.models import User
from django.utils.dateparse import parse_datetime
from django.utils.timezone import now, is_aware, make_aware
from django.utils.tzinfo import FixedOffset

# AWX
from awx.main.models import *
from awx.main.tasks import handle_work_error
from awx.main.utils import get_system_task_capacity, decrypt_field

# ZeroMQ
import zmq

# Celery
from celery.task.control import inspect


class SimpleDAG(object):
    ''' A simple implementation of a directed acyclic graph '''

    def __init__(self):
        self.nodes = []
        self.edges = []

    def __contains__(self, obj):
        for node in self.nodes:
            if node['node_object'] == obj:
                return True
        return False

    def __len__(self):
        return len(self.nodes)

    def __iter__(self):
        return self.nodes.__iter__()

    def generate_graphviz_plot(self):
        def short_string_obj(obj):
            if type(obj) == Job:
                type_str = "Job"
            elif type(obj) == InventoryUpdate:
                type_str = "Inventory"
            elif type(obj) == ProjectUpdate:
                type_str = "Project"
            else:
                type_str = "Unknown"
            type_str += "-%s" % str(obj.id)
            return type_str

        doc = """
        digraph g {
        rankdir = LR
        """
        for n in self.nodes:
            doc += "%s [color = %s]\n" % (
                short_string_obj(n['node_object']),
                "red" if n['node_object'].status == 'running' else "black")
        for from_node, to_node in self.edges:
            doc += "%s -> %s;\n" % (
                short_string_obj(self.nodes[from_node]['node_object']),
                short_string_obj(self.nodes[to_node]['node_object']))
        doc += "}"
        with open('/tmp/graph.gv', 'w') as gv_file:
            gv_file.write(doc)

    def add_node(self, obj, metadata=None):
        if self.find_ord(obj) is None:
            self.nodes.append(dict(node_object=obj, metadata=metadata))

    def add_edge(self, from_obj, to_obj):
        from_obj_ord = self.find_ord(from_obj)
        to_obj_ord = self.find_ord(to_obj)
        if from_obj_ord is None or to_obj_ord is None:
            raise LookupError("Object not found")
        self.edges.append((from_obj_ord, to_obj_ord))

    def add_edges(self, edgelist):
        for edge_pair in edgelist:
            self.add_edge(edge_pair[0], edge_pair[1])

    def find_ord(self, obj):
        for idx in range(len(self.nodes)):
            if obj == self.nodes[idx]['node_object']:
                return idx
        return None

    def get_node_type(self, obj):
        if type(obj) == Job:
            return "ansible_playbook"
        elif type(obj) == InventoryUpdate:
            return "inventory_update"
        elif type(obj) == ProjectUpdate:
            return "project_update"
        return "unknown"

    def get_dependencies(self, obj):
        # Tasks this object depends on; edges point from dependent to dependency.
        antecedents = []
        this_ord = self.find_ord(obj)
        for node, dep in self.edges:
            if node == this_ord:
                antecedents.append(self.nodes[dep])
        return antecedents

    def get_dependents(self, obj):
        # Tasks that depend on this object.
        descendants = []
        this_ord = self.find_ord(obj)
        for node, dep in self.edges:
            if dep == this_ord:
                descendants.append(self.nodes[node])
        return descendants

    def get_leaf_nodes(self):
        # Nodes with no outstanding dependencies; these are eligible to run.
        leafs = []
        for n in self.nodes:
            if len(self.get_dependencies(n['node_object'])) < 1:
                leafs.append(n)
        return leafs
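
# A minimal usage sketch of SimpleDAG (illustrative comment only; 'job' and
# 'project_update' stand in for hypothetical saved task objects and nothing
# here executes at import time). Edges point from the dependent task to the
# task it depends on, so get_leaf_nodes() returns the tasks free to run:
#
#   dag = SimpleDAG()
#   dag.add_node(project_update)
#   dag.add_node(job)
#   dag.add_edge(job, project_update)   # job is blocked by project_update
#   dag.get_leaf_nodes()                # -> [{'node_object': project_update,
#                                       #      'metadata': None}]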

def get_tasks():
    ''' Fetch all Tower tasks that are relevant to the task management system '''
    # TODO: Replace this when we can grab all objects in a sane way
    graph_jobs = [j for j in Job.objects.filter(
        status__in=('new', 'waiting', 'pending', 'running'))]
    graph_inventory_updates = [iu for iu in InventoryUpdate.objects.filter(
        status__in=('new', 'waiting', 'pending', 'running'))]
    graph_project_updates = [pu for pu in ProjectUpdate.objects.filter(
        status__in=('new', 'waiting', 'pending', 'running'))]
    all_actions = sorted(graph_jobs + graph_inventory_updates + graph_project_updates,
                         key=lambda task: task.created)
    return all_actions


def rebuild_graph(message):
    '''
    Regenerate the task graph by refreshing known tasks from Tower, purging
    orphaned running tasks, and creating dependencies for new tasks before
    generating directed edge relationships between those tasks
    '''
    inspector = inspect()
    active_task_queues = inspector.active()
    active_tasks = []
    for queue in active_task_queues:
        active_tasks += [at['id'] for at in active_task_queues[queue]]

    all_sorted_tasks = get_tasks()
    if not all_sorted_tasks:
        return None

    running_tasks = [t for t in all_sorted_tasks if t.status == 'running']
    waiting_tasks = [t for t in all_sorted_tasks if t.status != 'running']
    new_tasks = [t for t in all_sorted_tasks if t.status == 'new']

    # Check running tasks and make sure they are active in celery
    if settings.DEBUG:
        print("Active celery tasks: " + str(active_tasks))
    # Iterate over a copy since we remove orphans from the list as we go.
    for task in list(running_tasks):
        if task.celery_task_id not in active_tasks:
            # Pull status again and make sure it didn't finish in the meantime
            task.status = 'failed'
            task.result_traceback += "Task was marked as running in Tower but was not present in Celery so it has been marked as failed"
            task.save()
            running_tasks.remove(task)
            if settings.DEBUG:
                print("Task %s appears orphaned... marking as failed" % task)

    # Create and process dependencies for new tasks
    for task in new_tasks:
        if settings.DEBUG:
            print("Checking dependencies for: %s" % str(task))
        # TODO: other 'new' tasks? Need to investigate this scenario
        task_dependencies = task.generate_dependencies(running_tasks + waiting_tasks)
        if settings.DEBUG:
            print("New dependencies: %s" % str(task_dependencies))
        for dep in task_dependencies:
            # We recalculate the created time for the moment to ensure the
            # dependencies are always sorted in the right order relative to
            # the dependent task
            time_delta = len(task_dependencies) - task_dependencies.index(dep)
            dep.created = task.created - datetime.timedelta(seconds=1 + time_delta)
            dep.status = 'waiting'
            dep.save()
            waiting_tasks.insert(waiting_tasks.index(task), dep)
        task.status = 'waiting'
        task.save()

    # Rebuild graph
    graph = SimpleDAG()
    if settings.DEBUG:
        print("Graph nodes: " + str(graph.nodes))
    for task in running_tasks:
        if settings.DEBUG:
            print("Adding running task: %s to graph" % str(task))
        graph.add_node(task)
    if settings.DEBUG:
        print("Waiting Tasks: %s" % str(waiting_tasks))
    for wait_task in waiting_tasks:
        node_dependencies = []
        for node in graph:
            if wait_task.is_blocked_by(node['node_object']):
                if settings.DEBUG:
                    print("Waiting task %s is blocked by %s" % (str(wait_task), node['node_object']))
                node_dependencies.append(node['node_object'])
        graph.add_node(wait_task)
        for dependency in node_dependencies:
            graph.add_edge(wait_task, dependency)
    if settings.DEBUG:
        print("Graph Edges: %s" % str(graph.edges))
        graph.generate_graphviz_plot()
    return graph
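
# Worked example of the dependency back-dating in rebuild_graph() above, with
# hypothetical timestamps: a new job created at 12:00:00 that generates two
# dependencies gets
#
#   dep[0]: time_delta = 2 - 0 = 2, created = 12:00:00 - 3s = 11:59:57
#   dep[1]: time_delta = 2 - 1 = 1, created = 12:00:00 - 2s = 11:59:58
#
# so get_tasks(), which sorts on 'created', always yields the generated
# dependencies ahead of the job that spawned them.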

def process_graph(graph, task_capacity):
    '''
    Given a task dependency graph, start and manage tasks given their
    priority and weight
    '''
    leaf_nodes = graph.get_leaf_nodes()
    running_nodes = [x for x in leaf_nodes if x['node_object'].status == 'running']
    running_impact = sum([t['node_object'].task_impact for t in running_nodes])
    ready_nodes = [x for x in leaf_nodes if x['node_object'].status != 'running']
    remaining_volume = task_capacity - running_impact
    if settings.DEBUG:
        print("Running Nodes: %s; Capacity: %s; Running Impact: %s; Remaining Capacity: %s" %
              (str(running_nodes), str(task_capacity), str(running_impact), str(remaining_volume)))
        print("Ready Nodes: %s" % str(ready_nodes))
    for task_node in ready_nodes:
        node_obj = task_node['node_object']
        node_args = task_node['metadata']
        impact = node_obj.task_impact
        # Start the task if it fits in the remaining capacity, or if nothing
        # at all is running (so oversized tasks can still make progress).
        if impact <= remaining_volume or running_impact == 0:
            dependent_nodes = [{'type': graph.get_node_type(n['node_object']),
                                'id': n['node_object'].id}
                               for n in graph.get_dependents(node_obj)]
            error_handler = handle_work_error.s(subtasks=dependent_nodes)
            start_status = node_obj.start(error_callback=error_handler)
            if not start_status:
                node_obj.status = 'failed'
                node_obj.result_traceback += "Task failed pre-start check"
                node_obj.save()  # persist the failure; otherwise the status change is lost
                # TODO: Run error handler
                continue
            remaining_volume -= impact
            running_impact += impact
            if settings.DEBUG:
                print("Started Node: %s (capacity hit: %s) Remaining Capacity: %s" %
                      (str(node_obj), str(impact), str(remaining_volume)))


def run_taskmanager(command_port):
    '''
    Receive task start and finish signals to rebuild a dependency graph and
    manage the actual running of tasks
    '''
    paused = False
    task_capacity = get_system_task_capacity()
    command_context = zmq.Context()
    command_socket = command_context.socket(zmq.REP)
    command_socket.bind(command_port)
    if settings.DEBUG:
        print("Listening on %s" % command_port)
    last_rebuild = datetime.datetime.fromtimestamp(0)
    while True:
        try:
            message = command_socket.recv_json(flags=zmq.NOBLOCK)
            command_socket.send("1")
        except zmq.error.ZMQError:
            # No message waiting; fall through and check the rebuild timer.
            message = None
        if message is not None or (datetime.datetime.now() - last_rebuild).seconds > 60:
            if message is not None and 'pause' in message:
                if settings.DEBUG:
                    print("Pause command received: %s" % str(message))
                paused = message['pause']
            graph = rebuild_graph(message)
            if not paused and graph is not None:
                process_graph(graph, task_capacity)
            last_rebuild = datetime.datetime.now()
        time.sleep(0.1)


class Command(NoArgsCommand):

    help = 'Launch the job graph runner'

    def init_logging(self):
        log_levels = dict(enumerate([logging.ERROR, logging.INFO,
                                     logging.DEBUG, 0]))
        self.logger = logging.getLogger('awx.main.commands.run_task_system')
        self.logger.setLevel(log_levels.get(self.verbosity, 0))
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(handler)
        self.logger.propagate = False

    def handle_noargs(self, **options):
        self.verbosity = int(options.get('verbosity', 1))
        self.init_logging()
        command_port = settings.TASK_COMMAND_PORT
        try:
            run_taskmanager(command_port)
        except KeyboardInterrupt:
            pass
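
# A minimal sketch of a client driving the loop above over ZeroMQ (assumes
# the same settings.TASK_COMMAND_PORT endpoint the server binds to; the
# 'tcp://127.0.0.1:6559' value is only an illustration):
#
#   import zmq
#   context = zmq.Context()
#   socket = context.socket(zmq.REQ)
#   socket.connect('tcp://127.0.0.1:6559')
#   socket.send_json({'pause': True})   # any JSON message forces a graph rebuild;
#   socket.recv()                       # a 'pause' key also toggles scheduling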