move code linting to a stricter pep8-esque auto-formatting tool, black

This commit is contained in:
Ryan Petrello
2021-03-19 12:44:51 -04:00
parent 9b702e46fe
commit c2ef0a6500
671 changed files with 20538 additions and 21924 deletions

View File

@@ -4,6 +4,10 @@
# AWX
from awx.main.utils.common import * # noqa
from awx.main.utils.encryption import ( # noqa
get_encryption_key, encrypt_field, decrypt_field, encrypt_value,
decrypt_value, encrypt_dict,
get_encryption_key,
encrypt_field,
decrypt_field,
encrypt_value,
decrypt_value,
encrypt_dict,
)

View File

@@ -49,12 +49,7 @@ def could_be_playbook(project_path, dir_path, filename):
# show up.
matched = False
try:
for n, line in enumerate(codecs.open(
playbook_path,
'r',
encoding='utf-8',
errors='ignore'
)):
for n, line in enumerate(codecs.open(playbook_path, 'r', encoding='utf-8', errors='ignore')):
if valid_playbook_re.match(line):
matched = True
break
@@ -89,12 +84,7 @@ def could_be_inventory(project_path, dir_path, filename):
# Ansible inventory mainly
try:
# only read through first 10 lines for performance
with codecs.open(
inventory_path,
'r',
encoding='utf-8',
errors='ignore'
) as inv_file:
with codecs.open(inventory_path, 'r', encoding='utf-8', errors='ignore') as inv_file:
for line in islice(inv_file, 10):
if not valid_inventory_re.match(line):
return None

View File

@@ -23,10 +23,7 @@ from django.core.exceptions import ObjectDoesNotExist, FieldDoesNotExist
from django.utils.translation import ugettext_lazy as _
from django.utils.functional import cached_property
from django.db.models.fields.related import ForeignObjectRel, ManyToManyField
from django.db.models.fields.related_descriptors import (
ForwardManyToOneDescriptor,
ManyToManyDescriptor
)
from django.db.models.fields.related_descriptors import ForwardManyToOneDescriptor, ManyToManyDescriptor
from django.db.models.query import QuerySet
from django.db.models import Q
@@ -42,30 +39,65 @@ from awx.conf.license import get_license
logger = logging.getLogger('awx.main.utils')
__all__ = [
'get_object_or_400', 'camelcase_to_underscore', 'underscore_to_camelcase', 'memoize',
'memoize_delete', 'get_ansible_version', 'get_licenser', 'get_awx_http_client_headers',
'get_awx_version', 'update_scm_url', 'get_type_for_model', 'get_model_for_type',
'copy_model_by_class', 'copy_m2m_relationships',
'prefetch_page_capabilities', 'to_python_boolean', 'ignore_inventory_computed_fields',
'ignore_inventory_group_removal', '_inventory_updates', 'get_pk_from_dict', 'getattrd',
'getattr_dne', 'NoDefaultProvided', 'get_current_apps', 'set_current_apps',
'extract_ansible_vars', 'get_search_fields', 'get_system_task_capacity',
'get_cpu_capacity', 'get_mem_capacity', 'wrap_args_with_proot', 'build_proot_temp_dir',
'check_proot_installed', 'model_to_dict', 'NullablePromptPseudoField',
'model_instance_diff', 'parse_yaml_or_json', 'RequireDebugTrueOrTest',
'has_model_field_prefetched', 'set_environ', 'IllegalArgumentError',
'get_custom_venv_choices', 'get_external_account', 'task_manager_bulk_reschedule',
'schedule_task_manager', 'classproperty', 'create_temporary_fifo', 'truncate_stdout',
'deepmerge'
'get_object_or_400',
'camelcase_to_underscore',
'underscore_to_camelcase',
'memoize',
'memoize_delete',
'get_ansible_version',
'get_licenser',
'get_awx_http_client_headers',
'get_awx_version',
'update_scm_url',
'get_type_for_model',
'get_model_for_type',
'copy_model_by_class',
'copy_m2m_relationships',
'prefetch_page_capabilities',
'to_python_boolean',
'ignore_inventory_computed_fields',
'ignore_inventory_group_removal',
'_inventory_updates',
'get_pk_from_dict',
'getattrd',
'getattr_dne',
'NoDefaultProvided',
'get_current_apps',
'set_current_apps',
'extract_ansible_vars',
'get_search_fields',
'get_system_task_capacity',
'get_cpu_capacity',
'get_mem_capacity',
'wrap_args_with_proot',
'build_proot_temp_dir',
'check_proot_installed',
'model_to_dict',
'NullablePromptPseudoField',
'model_instance_diff',
'parse_yaml_or_json',
'RequireDebugTrueOrTest',
'has_model_field_prefetched',
'set_environ',
'IllegalArgumentError',
'get_custom_venv_choices',
'get_external_account',
'task_manager_bulk_reschedule',
'schedule_task_manager',
'classproperty',
'create_temporary_fifo',
'truncate_stdout',
'deepmerge',
]
def get_object_or_400(klass, *args, **kwargs):
'''
"""
Return a single object from the given model or queryset based on the query
params, otherwise raise an exception that will return in a 400 response.
'''
"""
from django.shortcuts import _get_queryset
queryset = _get_queryset(klass)
try:
return queryset.get(*args, **kwargs)
@@ -88,28 +120,28 @@ def to_python_boolean(value, allow_none=False):
def camelcase_to_underscore(s):
'''
"""
Convert CamelCase names to lowercase_with_underscore.
'''
"""
s = re.sub(r'(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))', '_\\1', s)
return s.lower().strip('_')
def underscore_to_camelcase(s):
'''
"""
Convert lowercase_with_underscore names to CamelCase.
'''
"""
return ''.join(x.capitalize() or '_' for x in s.split('_'))
class RequireDebugTrueOrTest(logging.Filter):
'''
"""
Logging filter to output when in DEBUG mode or running tests.
'''
"""
def filter(self, record):
from django.conf import settings
return settings.DEBUG or settings.IS_TESTING()
@@ -119,13 +151,14 @@ class IllegalArgumentError(ValueError):
def get_memoize_cache():
from django.core.cache import cache
return cache
def memoize(ttl=60, cache_key=None, track_function=False, cache=None):
'''
"""
Decorator to wrap a function and cache its result.
'''
"""
if cache_key and track_function:
raise IllegalArgumentError("Can not specify cache_key when track_function is True")
cache = cache or get_memoize_cache()
@@ -164,13 +197,12 @@ def memoize_delete(function_name):
@memoize()
def get_ansible_version():
'''
"""
Return Ansible version installed.
Ansible path needs to be provided to account for custom virtual environments
'''
"""
try:
proc = subprocess.Popen(['ansible', '--version'],
stdout=subprocess.PIPE)
proc = subprocess.Popen(['ansible', '--version'], stdout=subprocess.PIPE)
result = smart_str(proc.communicate()[0])
return result.split('\n')[0].replace('ansible', '').strip()
except Exception:
@@ -178,12 +210,14 @@ def get_ansible_version():
def get_awx_version():
'''
"""
Return AWX version as reported by setuptools.
'''
"""
from awx import __version__
try:
import pkg_resources
return pkg_resources.require('awx')[0].version
except Exception:
return __version__
@@ -193,17 +227,14 @@ def get_awx_http_client_headers():
license = get_license().get('license_type', 'UNLICENSED')
headers = {
'Content-Type': 'application/json',
'User-Agent': '{} {} ({})'.format(
'AWX' if license == 'open' else 'Red Hat Ansible Tower',
get_awx_version(),
license
)
'User-Agent': '{} {} ({})'.format('AWX' if license == 'open' else 'Red Hat Ansible Tower', get_awx_version(), license),
}
return headers
def get_licenser(*args, **kwargs):
from awx.main.utils.licensing import Licenser, OpenLicense
try:
if os.path.exists('/var/lib/awx/.tower_version'):
return Licenser(*args, **kwargs)
@@ -213,14 +244,13 @@ def get_licenser(*args, **kwargs):
raise ValueError(_('Error importing Tower License: %s') % e)
def update_scm_url(scm_type, url, username=True, password=True,
check_special_cases=True, scp_format=False):
'''
def update_scm_url(scm_type, url, username=True, password=True, check_special_cases=True, scp_format=False):
"""
Update the given SCM URL to add/replace/remove the username/password. When
username/password is True, preserve existing username/password, when
False (None, '', etc.), remove any existing username/password, otherwise
replace username/password. Also validates the given URL.
'''
"""
# Handle all of the URL formats supported by the SCM systems:
# git: https://www.kernel.org/pub/software/scm/git/docs/git-clone.html#URLS
# svn: http://svnbook.red-bean.com/en/1.7/svn-book.html#svn.advanced.reposurls
@@ -246,9 +276,9 @@ def update_scm_url(scm_type, url, username=True, password=True,
if hostpath.count(':') > 1:
raise ValueError(_('Invalid %s URL') % scm_type)
host, path = hostpath.split(':', 1)
#if not path.startswith('/') and not path.startswith('~/'):
# if not path.startswith('/') and not path.startswith('~/'):
# path = '~/%s' % path
#if path.startswith('/'):
# if path.startswith('/'):
# path = path.lstrip('/')
hostpath = '/'.join([host, path])
modified_url = '@'.join(filter(None, [userpass, hostpath]))
@@ -297,18 +327,17 @@ def update_scm_url(scm_type, url, username=True, password=True,
if scm_type == 'git' and parts.scheme.endswith('ssh') and parts.hostname in special_git_hosts and netloc_username != 'git':
raise ValueError(_('Username must be "git" for SSH access to %s.') % parts.hostname)
if scm_type == 'git' and parts.scheme.endswith('ssh') and parts.hostname in special_git_hosts and netloc_password:
#raise ValueError('Password not allowed for SSH access to %s.' % parts.hostname)
# raise ValueError('Password not allowed for SSH access to %s.' % parts.hostname)
netloc_password = ''
if netloc_username and parts.scheme != 'file' and scm_type not in ("insights", "archive"):
netloc = u':'.join([urllib.parse.quote(x,safe='') for x in (netloc_username, netloc_password) if x])
netloc = u':'.join([urllib.parse.quote(x, safe='') for x in (netloc_username, netloc_password) if x])
else:
netloc = u''
netloc = u'@'.join(filter(None, [netloc, parts.hostname]))
if parts.port:
netloc = u':'.join([netloc, str(parts.port)])
new_url = urllib.parse.urlunsplit([parts.scheme, netloc, parts.path,
parts.query, parts.fragment])
new_url = urllib.parse.urlunsplit([parts.scheme, netloc, parts.path, parts.query, parts.fragment])
if scp_format and parts.scheme == 'git+ssh':
new_url = new_url.replace('git+ssh://', '', 1).replace('/', ':', 1)
return new_url
@@ -322,11 +351,7 @@ def get_allowed_fields(obj, serializer_mapping):
else:
allowed_fields = [x.name for x in obj._meta.fields]
ACTIVITY_STREAM_FIELD_EXCLUSIONS = {
'user': ['last_login'],
'oauth2accesstoken': ['last_used'],
'oauth2application': ['client_secret']
}
ACTIVITY_STREAM_FIELD_EXCLUSIONS = {'user': ['last_login'], 'oauth2accesstoken': ['last_used'], 'oauth2application': ['client_secret']}
model_name = obj._meta.model_name
fields_excluded = ACTIVITY_STREAM_FIELD_EXCLUSIONS.get(model_name, [])
# see definition of from_db for CredentialType
@@ -347,10 +372,7 @@ def _convert_model_field_for_display(obj, field_name, password_fields=None):
return '<missing {}>-{}'.format(obj._meta.verbose_name, getattr(obj, '{}_id'.format(field_name)))
if password_fields is None:
password_fields = set(getattr(type(obj), 'PASSWORD_FIELDS', [])) | set(['password'])
if field_name in password_fields or (
isinstance(field_val, str) and
field_val.startswith('$encrypted$')
):
if field_name in password_fields or (isinstance(field_val, str) and field_val.startswith('$encrypted$')):
return u'hidden'
if hasattr(obj, 'display_%s' % field_name):
field_val = getattr(obj, 'display_%s' % field_name)()
@@ -373,9 +395,9 @@ def model_instance_diff(old, new, serializer_mapping=None):
"""
from django.db.models import Model
if not(old is None or isinstance(old, Model)):
if not (old is None or isinstance(old, Model)):
raise TypeError('The supplied old instance is not a valid model instance.')
if not(new is None or isinstance(new, Model)):
if not (new is None or isinstance(new, Model)):
raise TypeError('The supplied new instance is not a valid model instance.')
old_password_fields = set(getattr(type(old), 'PASSWORD_FIELDS', [])) | set(['password'])
new_password_fields = set(getattr(type(new), 'PASSWORD_FIELDS', [])) | set(['password'])
@@ -417,6 +439,7 @@ class CharPromptDescriptor:
"""Class used for identifying nullable launch config fields from class
ex. Schedule.limit
"""
def __init__(self, field):
self.field = field
@@ -426,6 +449,7 @@ class NullablePromptPseudoField:
Interface for pseudo-property stored in `char_prompts` dict
Used in LaunchTimeConfig and submodels, defined here to avoid circular imports
"""
def __init__(self, field_name):
self.field_name = field_name
@@ -447,10 +471,10 @@ class NullablePromptPseudoField:
def copy_model_by_class(obj1, Class2, fields, kwargs):
'''
"""
Creates a new unsaved object of type Class2 using the fields from obj1
values in kwargs can override obj1
'''
"""
create_kwargs = {}
for field_name in fields:
descriptor = getattr(Class2, field_name)
@@ -500,11 +524,11 @@ def copy_model_by_class(obj1, Class2, fields, kwargs):
def copy_m2m_relationships(obj1, obj2, fields, kwargs=None):
'''
"""
In-place operation.
Given two saved objects, copies related objects from obj1
to obj2 to field of same name, if field occurs in `fields`
'''
"""
for field_name in fields:
if hasattr(obj1, field_name):
try:
@@ -526,17 +550,17 @@ def copy_m2m_relationships(obj1, obj2, fields, kwargs=None):
def get_type_for_model(model):
'''
"""
Return type name for a given model class.
'''
"""
opts = model._meta.concrete_model._meta
return camelcase_to_underscore(opts.object_name)
def get_model_for_type(type_name):
'''
"""
Return model class for a given type name.
'''
"""
model_str = underscore_to_camelcase(type_name)
if model_str == 'User':
use_app = 'auth'
@@ -546,7 +570,7 @@ def get_model_for_type(type_name):
def prefetch_page_capabilities(model, page, prefetch_list, user):
'''
"""
Given a `page` list of objects, a nested dictionary of user_capabilities
are returned by id, ex.
{
@@ -565,7 +589,7 @@ def prefetch_page_capabilities(model, page, prefetch_list, user):
prefetch_list = [{'copy': ['inventory.admin', 'project.admin']}]
--> prefetch logical combination of admin permission to inventory AND
project, put into cache dictionary as "copy"
'''
"""
page_ids = [obj.id for obj in page]
mapping = {}
for obj in page:
@@ -592,9 +616,9 @@ def prefetch_page_capabilities(model, page, prefetch_list, user):
parent_model = model
for subpath in role_path.split('.')[:-1]:
parent_model = parent_model._meta.get_field(subpath).related_model
filter_args.append(Q(
Q(**{'%s__pk__in' % res_path: parent_model.accessible_pk_qs(user, '%s_role' % role_type)}) |
Q(**{'%s__isnull' % res_path: True})))
filter_args.append(
Q(Q(**{'%s__pk__in' % res_path: parent_model.accessible_pk_qs(user, '%s_role' % role_type)}) | Q(**{'%s__isnull' % res_path: True}))
)
else:
role_type = role_path
filter_args.append(Q(**{'pk__in': model.accessible_pk_qs(user, '%s_role' % role_type)}))
@@ -625,19 +649,16 @@ def validate_vars_type(vars_obj):
data_type = vars_type.__name__
else:
data_type = str(vars_type)
raise AssertionError(
_('Input type `{data_type}` is not a dictionary').format(
data_type=data_type)
)
raise AssertionError(_('Input type `{data_type}` is not a dictionary').format(data_type=data_type))
def parse_yaml_or_json(vars_str, silent_failure=True):
'''
"""
Attempt to parse a string of variables.
First, with JSON parser, if that fails, then with PyYAML.
If both attempts fail, return an empty dictionary if `silent_failure`
is True, re-raise combination error if `silent_failure` if False.
'''
"""
if isinstance(vars_str, dict):
return vars_str
elif isinstance(vars_str, str) and vars_str == '""':
@@ -658,21 +679,19 @@ def parse_yaml_or_json(vars_str, silent_failure=True):
try:
json.dumps(vars_dict)
except (ValueError, TypeError, AssertionError) as json_err2:
raise ParseError(_(
'Variables not compatible with JSON standard (error: {json_error})').format(
json_error=str(json_err2)))
raise ParseError(_('Variables not compatible with JSON standard (error: {json_error})').format(json_error=str(json_err2)))
except (yaml.YAMLError, TypeError, AttributeError, AssertionError) as yaml_err:
if silent_failure:
return {}
raise ParseError(_(
'Cannot parse as JSON (error: {json_error}) or '
'YAML (error: {yaml_error}).').format(
json_error=str(json_err), yaml_error=str(yaml_err)))
raise ParseError(
_('Cannot parse as JSON (error: {json_error}) or ' 'YAML (error: {yaml_error}).').format(json_error=str(json_err), yaml_error=str(yaml_err))
)
return vars_dict
def get_cpu_capacity():
from django.conf import settings
settings_forkcpu = getattr(settings, 'SYSTEM_TASK_FORKS_CPU', None)
env_forkcpu = os.getenv('SYSTEM_TASK_FORKS_CPU', None)
@@ -697,6 +716,7 @@ def get_cpu_capacity():
def get_mem_capacity():
from django.conf import settings
settings_forkmem = getattr(settings, 'SYSTEM_TASK_FORKS_MEM', None)
env_forkmem = os.getenv('SYSTEM_TASK_FORKS_MEM', None)
@@ -720,10 +740,11 @@ def get_mem_capacity():
def get_system_task_capacity(scale=Decimal(1.0), cpu_capacity=None, mem_capacity=None):
'''
"""
Measure system memory and use it as a baseline for determining the system's capacity
'''
"""
from django.conf import settings
settings_forks = getattr(settings, 'SYSTEM_TASK_FORKS_CAPACITY', None)
env_forks = os.getenv('SYSTEM_TASK_FORKS_CAPACITY', None)
@@ -749,9 +770,9 @@ _task_manager = threading.local()
@contextlib.contextmanager
def ignore_inventory_computed_fields():
'''
"""
Context manager to ignore updating inventory computed fields.
'''
"""
try:
previous_value = getattr(_inventory_updates, 'is_updating', False)
_inventory_updates.is_updating = True
@@ -763,14 +784,14 @@ def ignore_inventory_computed_fields():
def _schedule_task_manager():
from awx.main.scheduler.tasks import run_task_manager
from django.db import connection
# runs right away if not in transaction
connection.on_commit(lambda: run_task_manager.delay())
@contextlib.contextmanager
def task_manager_bulk_reschedule():
"""Context manager to avoid submitting task multiple times.
"""
"""Context manager to avoid submitting task multiple times."""
try:
previous_flag = getattr(_task_manager, 'bulk_reschedule', False)
previous_value = getattr(_task_manager, 'needs_scheduling', False)
@@ -793,9 +814,9 @@ def schedule_task_manager():
@contextlib.contextmanager
def ignore_inventory_group_removal():
'''
"""
Context manager to ignore moving groups/hosts when group is deleted.
'''
"""
try:
previous_value = getattr(_inventory_updates, 'is_removing', False)
_inventory_updates.is_removing = True
@@ -806,12 +827,12 @@ def ignore_inventory_group_removal():
@contextlib.contextmanager
def set_environ(**environ):
'''
"""
Temporarily set the process environment variables.
>>> with set_environ(FOO='BAR'):
... assert os.environ['FOO'] == 'BAR'
'''
"""
old_environ = os.environ.copy()
try:
os.environ.update(environ)
@@ -823,14 +844,14 @@ def set_environ(**environ):
@memoize()
def check_proot_installed():
'''
"""
Check that proot is installed.
'''
"""
from django.conf import settings
cmd = [getattr(settings, 'AWX_PROOT_CMD', 'bwrap'), '--version']
try:
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
proc.communicate()
return bool(proc.returncode == 0)
except (OSError, ValueError) as e:
@@ -840,17 +861,18 @@ def check_proot_installed():
def build_proot_temp_dir():
'''
"""
Create a temporary directory for proot to use.
'''
"""
from django.conf import settings
path = tempfile.mkdtemp(prefix='awx_proot_', dir=settings.AWX_PROOT_BASE_PATH)
os.chmod(path, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
return path
def wrap_args_with_proot(args, cwd, **kwargs):
'''
"""
Wrap existing command line with proot to restrict access to:
- AWX_PROOT_BASE_PATH (generally, /tmp) (except for own /tmp files)
For non-isolated nodes:
@@ -858,14 +880,14 @@ def wrap_args_with_proot(args, cwd, **kwargs):
- /var/lib/awx (except for current project)
- /var/log/tower
- /var/log/supervisor
'''
"""
from django.conf import settings
cwd = os.path.realpath(cwd)
new_args = [getattr(settings, 'AWX_PROOT_CMD', 'bwrap'), '--unshare-pid', '--dev-bind', '/', '/', '--proc', '/proc']
hide_paths = [settings.AWX_PROOT_BASE_PATH]
if not kwargs.get('isolated'):
hide_paths.extend(['/etc/tower', '/var/lib/awx', '/var/log', '/etc/ssh',
settings.PROJECTS_ROOT, settings.JOBOUTPUT_ROOT])
hide_paths.extend(['/etc/tower', '/var/lib/awx', '/var/log', '/etc/ssh', settings.PROJECTS_ROOT, settings.JOBOUTPUT_ROOT])
hide_paths.extend(getattr(settings, 'AWX_PROOT_HIDE_PATHS', None) or [])
for path in sorted(set(hide_paths)):
if not os.path.exists(path):
@@ -878,18 +900,14 @@ def wrap_args_with_proot(args, cwd, **kwargs):
handle, new_path = tempfile.mkstemp(dir=kwargs['proot_temp_dir'])
os.close(handle)
os.chmod(new_path, stat.S_IRUSR | stat.S_IWUSR)
new_args.extend(['--bind', '%s' %(new_path,), '%s' % (path,)])
new_args.extend(['--bind', '%s' % (new_path,), '%s' % (path,)])
if kwargs.get('isolated'):
show_paths = [kwargs['private_data_dir']]
elif 'private_data_dir' in kwargs:
show_paths = [cwd, kwargs['private_data_dir']]
else:
show_paths = [cwd]
for venv in (
settings.ANSIBLE_VENV_PATH,
settings.AWX_VENV_PATH,
kwargs.get('proot_custom_virtualenv')
):
for venv in (settings.ANSIBLE_VENV_PATH, settings.AWX_VENV_PATH, kwargs.get('proot_custom_virtualenv')):
if venv:
new_args.extend(['--ro-bind', venv, venv])
show_paths.extend(getattr(settings, 'AWX_PROOT_SHOW_PATHS', None) or [])
@@ -913,9 +931,9 @@ def wrap_args_with_proot(args, cwd, **kwargs):
def get_pk_from_dict(_dict, key):
'''
"""
Helper for obtaining a pk from user data dict or None if not present.
'''
"""
try:
val = _dict[key]
if isinstance(val, object) and hasattr(val, 'id'):
@@ -966,6 +984,7 @@ def get_current_apps():
def get_custom_venv_choices(custom_paths=None):
from django.conf import settings
custom_paths = custom_paths or settings.CUSTOM_VENV_PATHS
all_venv_paths = [settings.BASE_VENV_PATH] + custom_paths
custom_venv_choices = []
@@ -973,13 +992,15 @@ def get_custom_venv_choices(custom_paths=None):
for custom_venv_path in all_venv_paths:
try:
if os.path.exists(custom_venv_path):
custom_venv_choices.extend([
os.path.join(custom_venv_path, x, '')
for x in os.listdir(custom_venv_path)
if x != 'awx' and
os.path.isdir(os.path.join(custom_venv_path, x)) and
os.path.exists(os.path.join(custom_venv_path, x, 'bin', 'activate'))
])
custom_venv_choices.extend(
[
os.path.join(custom_venv_path, x, '')
for x in os.listdir(custom_venv_path)
if x != 'awx'
and os.path.isdir(os.path.join(custom_venv_path, x))
and os.path.exists(os.path.join(custom_venv_path, x, 'bin', 'activate'))
]
)
except Exception:
logger.exception("Encountered an error while discovering custom virtual environments.")
return custom_venv_choices
@@ -1002,20 +1023,19 @@ def extract_ansible_vars(extra_vars):
def get_search_fields(model):
fields = []
for field in model._meta.fields:
if field.name in ('username', 'first_name', 'last_name', 'email',
'name', 'description'):
if field.name in ('username', 'first_name', 'last_name', 'email', 'name', 'description'):
fields.append(field.name)
return fields
def has_model_field_prefetched(model_obj, field_name):
# NOTE: Update this function if django internal implementation changes.
return getattr(getattr(model_obj, field_name, None),
'prefetch_cache_name', '') in getattr(model_obj, '_prefetched_objects_cache', {})
return getattr(getattr(model_obj, field_name, None), 'prefetch_cache_name', '') in getattr(model_obj, '_prefetched_objects_cache', {})
def get_external_account(user):
from django.conf import settings
account_type = None
if getattr(settings, 'AUTH_LDAP_SERVER_URI', None):
try:
@@ -1023,20 +1043,20 @@ def get_external_account(user):
account_type = "ldap"
except AttributeError:
pass
if (getattr(settings, 'SOCIAL_AUTH_GOOGLE_OAUTH2_KEY', None) or
getattr(settings, 'SOCIAL_AUTH_GITHUB_KEY', None) or
getattr(settings, 'SOCIAL_AUTH_GITHUB_ORG_KEY', None) or
getattr(settings, 'SOCIAL_AUTH_GITHUB_TEAM_KEY', None) or
getattr(settings, 'SOCIAL_AUTH_SAML_ENABLED_IDPS', None)) and user.social_auth.all():
if (
getattr(settings, 'SOCIAL_AUTH_GOOGLE_OAUTH2_KEY', None)
or getattr(settings, 'SOCIAL_AUTH_GITHUB_KEY', None)
or getattr(settings, 'SOCIAL_AUTH_GITHUB_ORG_KEY', None)
or getattr(settings, 'SOCIAL_AUTH_GITHUB_TEAM_KEY', None)
or getattr(settings, 'SOCIAL_AUTH_SAML_ENABLED_IDPS', None)
) and user.social_auth.all():
account_type = "social"
if (getattr(settings, 'RADIUS_SERVER', None) or
getattr(settings, 'TACACSPLUS_HOST', None)) and user.enterprise_auth.all():
if (getattr(settings, 'RADIUS_SERVER', None) or getattr(settings, 'TACACSPLUS_HOST', None)) and user.enterprise_auth.all():
account_type = "enterprise"
return account_type
class classproperty:
def __init__(self, fget=None, fset=None, fdel=None, doc=None):
self.fget = fget
self.fset = fset
@@ -1058,10 +1078,7 @@ def create_temporary_fifo(data):
path = os.path.join(tempfile.mkdtemp(), next(tempfile._get_candidate_names()))
os.mkfifo(path, stat.S_IRUSR | stat.S_IWUSR)
threading.Thread(
target=lambda p, d: open(p, 'wb').write(d),
args=(path, data)
).start()
threading.Thread(target=lambda p, d: open(p, 'wb').write(d), args=(path, data)).start()
return path
@@ -1071,7 +1088,7 @@ def truncate_stdout(stdout, size):
if size <= 0 or len(stdout) <= size:
return stdout
stdout = stdout[:(size - 1)] + u'\u2026'
stdout = stdout[: (size - 1)] + u'\u2026'
set_count, reset_count = 0, 0
for m in ANSI_SGR_PATTERN.finditer(stdout):
if m.group() == u'\u001b[0m':
@@ -1092,8 +1109,7 @@ def deepmerge(a, b):
{'first': {'all_rows': {'fail': 'cat', 'number': '5', 'pass': 'dog'}}}
"""
if isinstance(a, dict) and isinstance(b, dict):
return dict([(k, deepmerge(a.get(k), b.get(k)))
for k in set(a.keys()).union(b.keys())])
return dict([(k, deepmerge(a.get(k), b.get(k))) for k in set(a.keys()).union(b.keys())])
elif b is None:
return a
else:

View File

@@ -7,10 +7,14 @@ from itertools import chain
def get_all_field_names(model):
# Implements compatibility with _meta.get_all_field_names
# See: https://docs.djangoproject.com/en/1.11/ref/models/meta/#migrating-from-the-old-api
return list(set(chain.from_iterable(
(field.name, field.attname) if hasattr(field, 'attname') else (field.name,)
for field in model._meta.get_fields()
# For complete backwards compatibility, you may want to exclude
# GenericForeignKey from the results.
if not (field.many_to_one and field.related_model is None)
)))
return list(
set(
chain.from_iterable(
(field.name, field.attname) if hasattr(field, 'attname') else (field.name,)
for field in model._meta.get_fields()
# For complete backwards compatibility, you may want to exclude
# GenericForeignKey from the results.
if not (field.many_to_one and field.related_model is None)
)
)
)

View File

@@ -1,6 +1,8 @@
from django.contrib.contenttypes.models import ContentType
from django.db.models.deletion import (
DO_NOTHING, Collector, get_candidate_relations_to_delete,
DO_NOTHING,
Collector,
get_candidate_relations_to_delete,
)
from collections import Counter, OrderedDict
from django.db import transaction
@@ -12,17 +14,18 @@ def bulk_related_objects(field, objs, using):
"""
Return all objects related to ``objs`` via this ``GenericRelation``.
"""
return field.remote_field.model._base_manager.db_manager(using).filter(**{
"%s__pk" % field.content_type_field_name: ContentType.objects.db_manager(using).get_for_model(
field.model, for_concrete_model=field.for_concrete_model).pk,
"%s__in" % field.object_id_field_name: list(objs.values_list('pk', flat=True))
})
return field.remote_field.model._base_manager.db_manager(using).filter(
**{
"%s__pk"
% field.content_type_field_name: ContentType.objects.db_manager(using).get_for_model(field.model, for_concrete_model=field.for_concrete_model).pk,
"%s__in" % field.object_id_field_name: list(objs.values_list('pk', flat=True)),
}
)
def pre_delete(qs):
# taken from .delete method in django.db.models.query.py
assert qs.query.can_filter(), \
"Cannot use 'limit' or 'offset' with delete."
assert qs.query.can_filter(), "Cannot use 'limit' or 'offset' with delete."
if qs._fields is not None:
raise TypeError("Cannot call delete() after .values() or .values_list()")
@@ -42,7 +45,6 @@ def pre_delete(qs):
class AWXCollector(Collector):
def add(self, objs, source=None, nullable=False, reverse_dependency=False):
"""
Add 'objs' to the collection of objects to be deleted. If the call is
@@ -62,8 +64,7 @@ class AWXCollector(Collector):
if source is not None and not nullable:
if reverse_dependency:
source, model = model, source
self.dependencies.setdefault(
source._meta.concrete_model, set()).add(model._meta.concrete_model)
self.dependencies.setdefault(source._meta.concrete_model, set()).add(model._meta.concrete_model)
return objs
def add_field_update(self, field, value, objs):
@@ -78,8 +79,7 @@ class AWXCollector(Collector):
self.field_updates[model].setdefault((field, value), [])
self.field_updates[model][(field, value)].append(objs)
def collect(self, objs, source=None, nullable=False, collect_related=True,
source_attr=None, reverse_dependency=False, keep_parents=False):
def collect(self, objs, source=None, nullable=False, collect_related=True, source_attr=None, reverse_dependency=False, keep_parents=False):
"""
Add 'objs' to the collection of objects to be deleted as well as all
parent instances. 'objs' must be a homogeneous iterable collection of
@@ -104,8 +104,7 @@ class AWXCollector(Collector):
if self.can_fast_delete(objs):
self.fast_deletes.append(objs)
return
new_objs = self.add(objs, source, nullable,
reverse_dependency=reverse_dependency)
new_objs = self.add(objs, source, nullable, reverse_dependency=reverse_dependency)
if not new_objs.exists():
return
@@ -117,10 +116,8 @@ class AWXCollector(Collector):
concrete_model = model._meta.concrete_model
for ptr in concrete_model._meta.parents.keys():
if ptr:
parent_objs = ptr.objects.filter(pk__in = new_objs.values_list('pk', flat=True))
self.collect(parent_objs, source=model,
collect_related=False,
reverse_dependency=True)
parent_objs = ptr.objects.filter(pk__in=new_objs.values_list('pk', flat=True))
self.collect(parent_objs, source=model, collect_related=False, reverse_dependency=True)
if collect_related:
parents = model._meta.parents
for related in get_candidate_relations_to_delete(model._meta):
@@ -161,8 +158,7 @@ class AWXCollector(Collector):
for (field, value), instances in instances_for_fieldvalues.items():
for inst in instances:
query = sql.UpdateQuery(model)
query.update_batch(inst.values_list('pk', flat=True),
{field.name: value}, self.using)
query.update_batch(inst.values_list('pk', flat=True), {field.name: value}, self.using)
# fast deletes
for qs in self.fast_deletes:
count = qs._raw_delete(using=self.using)

View File

@@ -10,27 +10,23 @@ from cryptography.hazmat.backends import default_backend
from django.utils.encoding import smart_str, smart_bytes
__all__ = ['get_encryption_key',
'encrypt_field', 'decrypt_field',
'encrypt_value', 'decrypt_value',
'encrypt_dict']
__all__ = ['get_encryption_key', 'encrypt_field', 'decrypt_field', 'encrypt_value', 'decrypt_value', 'encrypt_dict']
logger = logging.getLogger('awx.main.utils.encryption')
class Fernet256(Fernet):
'''Not techincally Fernet, but uses the base of the Fernet spec and uses AES-256-CBC
"""Not techincally Fernet, but uses the base of the Fernet spec and uses AES-256-CBC
instead of AES-128-CBC. All other functionality remain identical.
'''
"""
def __init__(self, key, backend=None):
if backend is None:
backend = default_backend()
key = base64.urlsafe_b64decode(key)
if len(key) != 64:
raise ValueError(
"Fernet key must be 64 url-safe base64-encoded bytes."
)
raise ValueError("Fernet key must be 64 url-safe base64-encoded bytes.")
self._signing_key = key[:32]
self._encryption_key = key[32:]
@@ -38,15 +34,16 @@ class Fernet256(Fernet):
def get_encryption_key(field_name, pk=None, secret_key=None):
'''
"""
Generate key for encrypted password based on field name,
``settings.SECRET_KEY``, and instance pk (if available).
:param pk: (optional) the primary key of the model object;
can be omitted in situations where you're encrypting a setting
that is not database-persistent (like a read-only setting)
'''
"""
from django.conf import settings
h = hashlib.sha512()
h.update(smart_bytes(secret_key or settings.SECRET_KEY))
if pk is not None:
@@ -100,9 +97,9 @@ def encrypt_field(instance, field_name, ask=False, subfield=None, secret_key=Non
# 2. Decrypting them using the *old* SECRET_KEY
# 3. Storing newly encrypted values using the *newly generated* SECRET_KEY
#
'''
"""
Return content of the given instance and field name encrypted.
'''
"""
try:
value = instance.inputs[field_name]
except (TypeError, AttributeError):
@@ -117,11 +114,7 @@ def encrypt_field(instance, field_name, ask=False, subfield=None, secret_key=Non
value = smart_str(value)
if not value or value.startswith('$encrypted$') or (ask and value == 'ASK'):
return value
key = get_encryption_key(
field_name,
getattr(instance, 'pk', None),
secret_key=secret_key
)
key = get_encryption_key(field_name, getattr(instance, 'pk', None), secret_key=secret_key)
f = Fernet256(key)
encrypted = f.encrypt(smart_bytes(value))
b64data = smart_str(base64.b64encode(encrypted))
@@ -130,11 +123,11 @@ def encrypt_field(instance, field_name, ask=False, subfield=None, secret_key=Non
def decrypt_value(encryption_key, value):
raw_data = value[len('$encrypted$'):]
raw_data = value[len('$encrypted$') :]
# If the encrypted string contains a UTF8 marker, discard it
utf8 = raw_data.startswith('UTF8$')
if utf8:
raw_data = raw_data[len('UTF8$'):]
raw_data = raw_data[len('UTF8$') :]
algo, b64data = raw_data.split('$', 1)
if algo != 'AESCBC':
raise ValueError('unsupported algorithm: %s' % algo)
@@ -145,9 +138,9 @@ def decrypt_value(encryption_key, value):
def decrypt_field(instance, field_name, subfield=None, secret_key=None):
'''
"""
Return content of the given instance and field name decrypted.
'''
"""
try:
value = instance.inputs[field_name]
except (TypeError, AttributeError):
@@ -160,11 +153,7 @@ def decrypt_field(instance, field_name, subfield=None, secret_key=None):
value = smart_str(value)
if not value or not value.startswith('$encrypted$'):
return value
key = get_encryption_key(
field_name,
getattr(instance, 'pk', None),
secret_key=secret_key
)
key = get_encryption_key(field_name, getattr(instance, 'pk', None), secret_key=secret_key)
try:
return smart_str(decrypt_value(key, value))
@@ -176,16 +165,16 @@ def decrypt_field(instance, field_name, subfield=None, secret_key=None):
instance.__class__.__name__,
getattr(instance, 'pk', None),
field_name,
exc_info=True
exc_info=True,
)
raise
def encrypt_dict(data, fields):
'''
"""
Encrypts all of the dictionary values in `data` under the keys in `fields`
in-place operation on `data`
'''
"""
encrypt_fields = set(data.keys()).intersection(fields)
for key in encrypt_fields:
data[key] = encrypt_value(data[key])

View File

@@ -26,15 +26,17 @@ def construct_rsyslog_conf_template(settings=settings):
max_bytes = settings.MAX_EVENT_RES_DATA
if settings.LOG_AGGREGATOR_RSYSLOGD_DEBUG:
parts.append('$DebugLevel 2')
parts.extend([
'$WorkDirectory /var/lib/awx/rsyslog',
f'$MaxMessageSize {max_bytes}',
'$IncludeConfig /var/lib/awx/rsyslog/conf.d/*.conf',
f'main_queue(queue.spoolDirectory="{spool_directory}" queue.maxdiskspace="{max_disk_space}g" queue.type="Disk" queue.filename="awx-external-logger-backlog")', # noqa
'module(load="imuxsock" SysSock.Use="off")',
'input(type="imuxsock" Socket="' + settings.LOGGING['handlers']['external_logger']['address'] + '" unlink="on" RateLimit.Burst="0")',
'template(name="awx" type="string" string="%rawmsg-after-pri%")',
])
parts.extend(
[
'$WorkDirectory /var/lib/awx/rsyslog',
f'$MaxMessageSize {max_bytes}',
'$IncludeConfig /var/lib/awx/rsyslog/conf.d/*.conf',
f'main_queue(queue.spoolDirectory="{spool_directory}" queue.maxdiskspace="{max_disk_space}g" queue.type="Disk" queue.filename="awx-external-logger-backlog")', # noqa
'module(load="imuxsock" SysSock.Use="off")',
'input(type="imuxsock" Socket="' + settings.LOGGING['handlers']['external_logger']['address'] + '" unlink="on" RateLimit.Burst="0")',
'template(name="awx" type="string" string="%rawmsg-after-pri%")',
]
)
def escape_quotes(x):
return x.replace('"', '\\"')
@@ -43,7 +45,7 @@ def construct_rsyslog_conf_template(settings=settings):
parts.append('action(type="omfile" file="/dev/null")') # rsyslog needs *at least* one valid action to start
tmpl = '\n'.join(parts)
return tmpl
if protocol.startswith('http'):
scheme = 'https'
# urlparse requires '//' to be provided if scheme is not specified
@@ -75,7 +77,7 @@ def construct_rsyslog_conf_template(settings=settings):
f'skipverifyhost="{skip_verify}"',
'action.resumeRetryCount="-1"',
'template="awx"',
f'action.resumeInterval="{timeout}"'
f'action.resumeInterval="{timeout}"',
]
if error_log_file:
params.append(f'errorfile="{error_log_file}"')

View File

@@ -107,7 +107,6 @@ class ExternalLoggerEnabled(Filter):
class DynamicLevelFilter(Filter):
def filter(self, record):
"""Filters out logs that have a level below the threshold defined
by the databse setting LOG_AGGREGATOR_LEVEL
@@ -132,10 +131,10 @@ def string_to_type(t):
elif t == u'false':
return False
if re.search(r'^[-+]?[0-9]+$',t):
if re.search(r'^[-+]?[0-9]+$', t):
return int(t)
if re.search(r'^[-+]?[0-9]+\.[0-9]+$',t):
if re.search(r'^[-+]?[0-9]+\.[0-9]+$', t):
return float(t)
return t
@@ -158,12 +157,13 @@ class SmartFilter(object):
search_kwargs = self._expand_search(k, v)
if search_kwargs:
kwargs.update(search_kwargs)
q = reduce(lambda x, y: x | y, [models.Q(**{u'%s__icontains' % _k:_v}) for _k, _v in kwargs.items()])
q = reduce(lambda x, y: x | y, [models.Q(**{u'%s__icontains' % _k: _v}) for _k, _v in kwargs.items()])
self.result = Host.objects.filter(q)
else:
# detect loops and restrict access to sensitive fields
# this import is intentional here to avoid a circular import
from awx.api.filters import FieldLookupBackend
FieldLookupBackend().get_field_from_lookup(Host, k)
kwargs[k] = v
self.result = Host.objects.filter(**kwargs)
@@ -186,8 +186,10 @@ class SmartFilter(object):
accomplished using an allowed list or introspecting the
relationship refered to to see if it's a jsonb type.
'''
def _json_path_to_contains(self, k, v):
from awx.main.fields import JSONBField # avoid a circular import
if not k.startswith(SmartFilter.SEARCHABLE_RELATIONSHIP):
v = self.strip_quotes_traditional_logic(v)
return (k, v)
@@ -198,14 +200,9 @@ class SmartFilter(object):
if match == '__exact':
# appending __exact is basically a no-op, because that's
# what the query means if you leave it off
k = k[:-len(match)]
k = k[: -len(match)]
else:
logger.error(
'host_filter:{} does not support searching with {}'.format(
SmartFilter.SEARCHABLE_RELATIONSHIP,
match
)
)
logger.error('host_filter:{} does not support searching with {}'.format(SmartFilter.SEARCHABLE_RELATIONSHIP, match))
# Strip off leading relationship key
if k.startswith(SmartFilter.SEARCHABLE_RELATIONSHIP + '__'):
@@ -270,7 +267,7 @@ class SmartFilter(object):
# ="something"
if t_len > (v_offset + 2) and t[v_offset] == "\"" and t[v_offset + 2] == "\"":
v = u'"' + str(t[v_offset + 1]) + u'"'
#v = t[v_offset + 1]
# v = t[v_offset + 1]
# empty ""
elif t_len > (v_offset + 1):
v = u""
@@ -305,42 +302,38 @@ class SmartFilter(object):
search_kwargs[k] = v
return search_kwargs
class BoolBinOp(object):
def __init__(self, t):
self.result = None
i = 2
while i < len(t[0]):
'''
"""
Do NOT observe self.result. It will cause the sql query to be executed.
We do not want that. We only want to build the query.
'''
"""
if isinstance(self.result, type(None)):
self.result = t[0][0].result
right = t[0][i].result
self.result = self.execute_logic(self.result, right)
i += 2
class BoolAnd(BoolBinOp):
def execute_logic(self, left, right):
return left & right
class BoolOr(BoolBinOp):
def execute_logic(self, left, right):
return left | right
@classmethod
def query_from_string(cls, filter_string):
'''
"""
TODO:
* handle values with " via: a.b.c.d="hello\"world"
* handle keys with " via: a.\"b.c="yeah"
* handle key with __ in it
'''
"""
filter_string_raw = filter_string
filter_string = str(filter_string)
@@ -351,13 +344,16 @@ class SmartFilter(object):
atom_quoted = Literal('"') + Optional(atom_inside_quotes) + Literal('"')
EQUAL = Literal('=')
grammar = ((atom_quoted | atom) + EQUAL + Optional((atom_quoted | atom)))
grammar = (atom_quoted | atom) + EQUAL + Optional((atom_quoted | atom))
grammar.setParseAction(cls.BoolOperand)
boolExpr = infixNotation(grammar, [
("and", 2, opAssoc.LEFT, cls.BoolAnd),
("or", 2, opAssoc.LEFT, cls.BoolOr),
])
boolExpr = infixNotation(
grammar,
[
("and", 2, opAssoc.LEFT, cls.BoolAnd),
("or", 2, opAssoc.LEFT, cls.BoolOr),
],
)
try:
res = boolExpr.parseString('(' + filter_string + ')')
@@ -370,9 +366,7 @@ class SmartFilter(object):
raise RuntimeError("Parsing the filter_string %s went terribly wrong" % filter_string)
class DefaultCorrelationId(CorrelationId):
def filter(self, record):
guid = GuidMiddleware.get_guid() or '-'
if MODE == 'development':

View File

@@ -25,9 +25,10 @@ class JobLifeCycleFormatter(json_log_formatter.JSONFormatter):
class TimeFormatter(logging.Formatter):
'''
"""
Custom log formatter used for inventory imports
'''
"""
def __init__(self, start_time=None, **kwargs):
if start_time is None:
self.job_start = now()
@@ -81,10 +82,31 @@ class LogstashFormatterBase(logging.Formatter):
# The list contains all the attributes listed in
# http://docs.python.org/library/logging.html#logrecord-attributes
skip_list = (
'args', 'asctime', 'created', 'exc_info', 'exc_text', 'filename',
'funcName', 'id', 'levelname', 'levelno', 'lineno', 'module',
'msecs', 'msecs', 'message', 'msg', 'name', 'pathname', 'process',
'processName', 'relativeCreated', 'thread', 'threadName', 'extra')
'args',
'asctime',
'created',
'exc_info',
'exc_text',
'filename',
'funcName',
'id',
'levelname',
'levelno',
'lineno',
'module',
'msecs',
'msecs',
'message',
'msg',
'name',
'pathname',
'process',
'processName',
'relativeCreated',
'thread',
'threadName',
'extra',
)
easy_types = (str, bool, dict, float, int, list, type(None))
@@ -119,25 +141,21 @@ class LogstashFormatterBase(logging.Formatter):
class LogstashFormatter(LogstashFormatterBase):
def __init__(self, *args, **kwargs):
self.cluster_host_id = settings.CLUSTER_HOST_ID
self.tower_uuid = None
uuid = (
getattr(settings, 'LOG_AGGREGATOR_TOWER_UUID', None) or
getattr(settings, 'INSTALL_UUID', None)
)
uuid = getattr(settings, 'LOG_AGGREGATOR_TOWER_UUID', None) or getattr(settings, 'INSTALL_UUID', None)
if uuid:
self.tower_uuid = uuid
super(LogstashFormatter, self).__init__(*args, **kwargs)
def reformat_data_for_log(self, raw_data, kind=None):
'''
"""
Process dictionaries from various contexts (job events, activity stream
changes, etc.) to give meaningful information
Output a dictionary which will be passed in logstash or syslog format
to the logging receiver
'''
"""
if kind == 'activity_stream':
try:
raw_data['changes'] = json.loads(raw_data.get('changes', '{}'))
@@ -191,6 +209,7 @@ class LogstashFormatter(LogstashFormatterBase):
data_for_log['host_name'] = raw_data.get('host_name')
data_for_log['job_id'] = raw_data.get('job_id')
elif kind == 'performance':
def convert_to_type(t, val):
if t is float:
val = val[:-1] if val.endswith('s') else val
@@ -216,7 +235,7 @@ class LogstashFormatter(LogstashFormatterBase):
(float, 'X-API-Time'), # may end with an 's' "0.33s"
(float, 'X-API-Total-Time'),
(int, 'X-API-Query-Count'),
(float, 'X-API-Query-Time'), # may also end with an 's'
(float, 'X-API-Query-Time'), # may also end with an 's'
(str, 'X-API-Node'),
]
data_for_log['x_api'] = {k: convert_to_type(t, response[k]) for (t, k) in headers if k in response}
@@ -236,7 +255,7 @@ class LogstashFormatter(LogstashFormatterBase):
def get_extra_fields(self, record):
fields = super(LogstashFormatter, self).get_extra_fields(record)
if record.name.startswith('awx.analytics'):
log_kind = record.name[len('awx.analytics.'):]
log_kind = record.name[len('awx.analytics.') :]
fields = self.reformat_data_for_log(fields, kind=log_kind)
# General AWX metadata
fields['cluster_host_id'] = self.cluster_host_id
@@ -252,7 +271,6 @@ class LogstashFormatter(LogstashFormatterBase):
'@timestamp': stamp,
'message': record.getMessage(),
'host': self.host,
# Extra Fields
'level': record.levelname,
'logger_name': record.name,

View File

@@ -56,8 +56,7 @@ class SpecialInventoryHandler(logging.Handler):
as opposed to ansible-runner
"""
def __init__(self, event_handler, cancel_callback, job_timeout, verbosity,
start_time=None, counter=0, initial_line=0, **kwargs):
def __init__(self, event_handler, cancel_callback, job_timeout, verbosity, start_time=None, counter=0, initial_line=0, **kwargs):
self.event_handler = event_handler
self.cancel_callback = cancel_callback
self.job_timeout = job_timeout
@@ -89,12 +88,7 @@ class SpecialInventoryHandler(logging.Handler):
msg = self.format(record)
n_lines = len(msg.strip().split('\n')) # don't count line breaks at boundry of text
dispatch_data = dict(
created=now().isoformat(),
event='verbose',
counter=self.counter,
stdout=msg,
start_line=self._current_line,
end_line=self._current_line + n_lines
created=now().isoformat(), event='verbose', counter=self.counter, stdout=msg, start_line=self._current_line, end_line=self._current_line + n_lines
)
self._current_line += n_lines
@@ -120,10 +114,7 @@ if settings.COLOR_LOGS is True:
def format(self, record):
message = logging.StreamHandler.format(self, record)
return '\n'.join([
self.colorize(line, record)
for line in message.splitlines()
])
return '\n'.join([self.colorize(line, record) for line in message.splitlines()])
level_map = {
logging.DEBUG: (None, 'green', True),
@@ -132,6 +123,7 @@ if settings.COLOR_LOGS is True:
logging.ERROR: (None, 'red', True),
logging.CRITICAL: (None, 'red', True),
}
except ImportError:
# logutils is only used for colored logs in the dev environment
pass

View File

@@ -16,12 +16,7 @@
def filter_insights_api_response(platform_info, reports, remediations):
severity_mapping = {
1: 'INFO',
2: 'WARN',
3: 'ERROR',
4: 'CRITICAL'
}
severity_mapping = {1: 'INFO', 2: 'WARN', 3: 'ERROR', 4: 'CRITICAL'}
new_json = {
'platform_id': platform_info['id'],
@@ -29,10 +24,7 @@ def filter_insights_api_response(platform_info, reports, remediations):
'reports': [],
}
for rep in reports:
new_report = {
'rule': {},
'maintenance_actions': remediations
}
new_report = {'rule': {}, 'maintenance_actions': remediations}
rule = rep.get('rule') or {}
for k in ['description', 'summary']:
if k in rule:

View File

@@ -104,7 +104,7 @@ class Licenser(object):
license_date=0,
license_type="UNLICENSED",
product_name="Red Hat Ansible Automation Platform",
valid_key=False
valid_key=False,
)
def __init__(self, **kwargs):
@@ -128,11 +128,9 @@ class Licenser(object):
else:
self._unset_attrs()
def _unset_attrs(self):
self._attrs = self.UNLICENSED_DATA.copy()
def license_from_manifest(self, manifest):
def is_appropriate_manifest_sub(sub):
if sub['pool']['activeSubscription'] is False:
@@ -162,12 +160,12 @@ class Licenser(object):
license = dict()
for sub in manifest:
if not is_appropriate_manifest_sub(sub):
logger.warning("Subscription %s (%s) in manifest is not active or for another product" %
(sub['pool']['productName'], sub['pool']['productId']))
logger.warning("Subscription %s (%s) in manifest is not active or for another product" % (sub['pool']['productName'], sub['pool']['productId']))
continue
if not _can_aggregate(sub, license):
logger.warning("Subscription %s (%s) in manifest does not match other manifest subscriptions" %
(sub['pool']['productName'], sub['pool']['productId']))
logger.warning(
"Subscription %s (%s) in manifest does not match other manifest subscriptions" % (sub['pool']['productName'], sub['pool']['productId'])
)
continue
license.setdefault('sku', sub['pool']['productId'])
@@ -179,7 +177,7 @@ class Licenser(object):
license.setdefault('satellite', False)
# Use the nearest end date
endDate = parse_date(sub['endDate'])
currentEndDateStr = license.get('license_date', '4102462800') # 2100-01-01
currentEndDateStr = license.get('license_date', '4102462800') # 2100-01-01
currentEndDate = datetime.fromtimestamp(int(currentEndDateStr), timezone.utc)
if endDate < currentEndDate:
license['license_date'] = endDate.strftime('%s')
@@ -193,7 +191,6 @@ class Licenser(object):
settings.LICENSE = self._attrs
return self._attrs
def update(self, **kwargs):
# Update attributes of the current license.
if 'instance_count' in kwargs:
@@ -202,7 +199,6 @@ class Licenser(object):
kwargs['license_date'] = int(kwargs['license_date'])
self._attrs.update(kwargs)
def validate_rh(self, user, pw):
try:
host = 'https://' + str(self.config.get("server", "hostname"))
@@ -211,7 +207,7 @@ class Licenser(object):
host = None
if not host:
host = getattr(settings, 'REDHAT_CANDLEPIN_HOST', None)
if not user:
raise ValueError('subscriptions_username is required')
@@ -226,36 +222,25 @@ class Licenser(object):
return self.generate_license_options_from_entitlements(json)
return []
def get_rhsm_subs(self, host, user, pw):
verify = getattr(settings, 'REDHAT_CANDLEPIN_VERIFY', True)
json = []
try:
subs = requests.get(
'/'.join([host, 'subscription/users/{}/owners'.format(user)]),
verify=verify,
auth=(user, pw)
)
subs = requests.get('/'.join([host, 'subscription/users/{}/owners'.format(user)]), verify=verify, auth=(user, pw))
except requests.exceptions.ConnectionError as error:
raise error
except OSError as error:
raise OSError('Unable to open certificate bundle {}. Check that Ansible Tower is running on Red Hat Enterprise Linux.'.format(verify)) from error # noqa
raise OSError(
'Unable to open certificate bundle {}. Check that Ansible Tower is running on Red Hat Enterprise Linux.'.format(verify)
) from error # noqa
subs.raise_for_status()
for sub in subs.json():
resp = requests.get(
'/'.join([
host,
'subscription/owners/{}/pools/?match=*tower*'.format(sub['key'])
]),
verify=verify,
auth=(user, pw)
)
resp = requests.get('/'.join([host, 'subscription/owners/{}/pools/?match=*tower*'.format(sub['key'])]), verify=verify, auth=(user, pw))
resp.raise_for_status()
json.extend(resp.json())
return json
def get_satellite_subs(self, host, user, pw):
port = None
try:
@@ -268,25 +253,20 @@ class Licenser(object):
host = ':'.join([host, port])
json = []
try:
orgs = requests.get(
'/'.join([host, 'katello/api/organizations']),
verify=verify,
auth=(user, pw)
)
orgs = requests.get('/'.join([host, 'katello/api/organizations']), verify=verify, auth=(user, pw))
except requests.exceptions.ConnectionError as error:
raise error
except OSError as error:
raise OSError('Unable to open certificate bundle {}. Check that Ansible Tower is running on Red Hat Enterprise Linux.'.format(verify)) from error # noqa
raise OSError(
'Unable to open certificate bundle {}. Check that Ansible Tower is running on Red Hat Enterprise Linux.'.format(verify)
) from error # noqa
orgs.raise_for_status()
for org in orgs.json()['results']:
resp = requests.get(
'/'.join([
host,
'/katello/api/organizations/{}/subscriptions/?search=Red Hat Ansible Automation'.format(org['id'])
]),
'/'.join([host, '/katello/api/organizations/{}/subscriptions/?search=Red Hat Ansible Automation'.format(org['id'])]),
verify=verify,
auth=(user, pw)
auth=(user, pw),
)
resp.raise_for_status()
results = resp.json()['results']
@@ -307,13 +287,11 @@ class Licenser(object):
json.append(license)
return json
def is_appropriate_sat_sub(self, sub):
if 'Red Hat Ansible Automation' not in sub['subscription_name']:
return False
return True
def is_appropriate_sub(self, sub):
if sub['activeSubscription'] is False:
return False
@@ -323,9 +301,9 @@ class Licenser(object):
return True
return False
def generate_license_options_from_entitlements(self, json):
from dateutil.parser import parse
ValidSub = collections.namedtuple('ValidSub', 'sku name support_level end_date trial quantity pool_id satellite')
valid_subs = []
for sub in json:
@@ -363,9 +341,7 @@ class Licenser(object):
if attr.get('name') == 'support_level':
support_level = attr.get('value')
valid_subs.append(ValidSub(
sku, sub['productName'], support_level, end_date, trial, quantity, pool_id, satellite
))
valid_subs.append(ValidSub(sku, sub['productName'], support_level, end_date, trial, quantity, pool_id, satellite))
if valid_subs:
licenses = []
@@ -378,40 +354,27 @@ class Licenser(object):
if sub.trial:
license._attrs['trial'] = True
license._attrs['license_type'] = 'trial'
license._attrs['instance_count'] = min(
MAX_INSTANCES, license._attrs['instance_count']
)
license._attrs['instance_count'] = min(MAX_INSTANCES, license._attrs['instance_count'])
human_instances = license._attrs['instance_count']
if human_instances == MAX_INSTANCES:
human_instances = 'Unlimited'
subscription_name = re.sub(
r' \([\d]+ Managed Nodes',
' ({} Managed Nodes'.format(human_instances),
sub.name
)
subscription_name = re.sub(r' \([\d]+ Managed Nodes', ' ({} Managed Nodes'.format(human_instances), sub.name)
license._attrs['subscription_name'] = subscription_name
license._attrs['satellite'] = satellite
license._attrs['valid_key'] = True
license.update(
license_date=int(sub.end_date.strftime('%s'))
)
license.update(
pool_id=sub.pool_id
)
license.update(license_date=int(sub.end_date.strftime('%s')))
license.update(pool_id=sub.pool_id)
licenses.append(license._attrs.copy())
return licenses
raise ValueError(
'No valid Red Hat Ansible Automation subscription could be found for this account.' # noqa
)
raise ValueError('No valid Red Hat Ansible Automation subscription could be found for this account.') # noqa
def validate(self):
# Return license attributes with additional validation info.
attrs = copy.deepcopy(self._attrs)
type = attrs.get('license_type', 'none')
if (type == 'UNLICENSED' or False):
if type == 'UNLICENSED' or False:
attrs.update(dict(valid_key=False, compliant=False))
return attrs
attrs['valid_key'] = True
@@ -422,7 +385,7 @@ class Licenser(object):
current_instances = 0
instance_count = int(attrs.get('instance_count', 0))
attrs['current_instances'] = current_instances
free_instances = (instance_count - current_instances)
free_instances = instance_count - current_instances
attrs['free_instances'] = max(0, free_instances)
license_date = int(attrs.get('license_date', 0) or 0)

View File

@@ -12,8 +12,7 @@ from collections import OrderedDict
logger = logging.getLogger('awx.main.commands.inventory_import')
__all__ = ['MemHost', 'MemGroup', 'MemInventory',
'mem_data_to_dict', 'dict_to_mem_data']
__all__ = ['MemHost', 'MemGroup', 'MemInventory', 'mem_data_to_dict', 'dict_to_mem_data']
ipv6_port_re = re.compile(r'^\[([A-Fa-f0-9:]{3,})\]:(\d+?)$')
@@ -23,9 +22,9 @@ ipv6_port_re = re.compile(r'^\[([A-Fa-f0-9:]{3,})\]:(\d+?)$')
class MemObject(object):
'''
"""
Common code shared between in-memory groups and hosts.
'''
"""
def __init__(self, name):
assert name, 'no name'
@@ -33,9 +32,9 @@ class MemObject(object):
class MemGroup(MemObject):
'''
"""
In-memory representation of an inventory group.
'''
"""
def __init__(self, name):
super(MemGroup, self).__init__(name)
@@ -75,7 +74,7 @@ class MemGroup(MemObject):
logger.debug('Dumping tree for group "%s":', self.name)
logger.debug('- Vars: %r', self.variables)
for h in self.hosts:
logger.debug('- Host: %s, %r', h.name, h.variables)
logger.debug('- Host: %s, %r', h.name, h.variables)
for g in self.children:
logger.debug('- Child: %s', g.name)
logger.debug('----')
@@ -85,9 +84,9 @@ class MemGroup(MemObject):
class MemHost(MemObject):
'''
"""
In-memory representation of an inventory host.
'''
"""
def __init__(self, name, port=None):
super(MemHost, self).__init__(name)
@@ -104,9 +103,10 @@ class MemHost(MemObject):
class MemInventory(object):
'''
"""
Common functions for an inventory loader from a given source.
'''
"""
def __init__(self, all_group=None, group_filter_re=None, host_filter_re=None):
if all_group:
assert isinstance(all_group, MemGroup), '{} is not MemGroup instance'.format(all_group)
@@ -122,10 +122,10 @@ class MemInventory(object):
return host
def get_host(self, name):
'''
"""
Return a MemHost instance from host name, creating if needed. If name
contains brackets, they will NOT be interpreted as a host pattern.
'''
"""
m = ipv6_port_re.match(name)
if m:
host_name = m.groups()[0]
@@ -135,8 +135,7 @@ class MemInventory(object):
try:
port = int(name.split(':')[1])
except (ValueError, UnicodeDecodeError):
logger.warning(u'Invalid port "%s" for host "%s"',
name.split(':')[1], host_name)
logger.warning(u'Invalid port "%s" for host "%s"', name.split(':')[1], host_name)
port = None
else:
host_name = name
@@ -155,9 +154,9 @@ class MemInventory(object):
return group
def get_group(self, name, all_group=None, child=False):
'''
"""
Return a MemGroup instance from group name, creating if needed.
'''
"""
all_group = all_group or self.all_group
if name in ['all', 'ungrouped']:
return all_group
@@ -182,13 +181,14 @@ class MemInventory(object):
# Conversion utilities
def mem_data_to_dict(inventory):
'''
"""
Given an in-memory construct of an inventory, returns a dictionary that
follows Ansible guidelines on the structure of dynamic inventory sources
May be replaced by removing in-memory constructs within this file later
'''
"""
all_group = inventory.all_group
inventory_data = OrderedDict([])
# Save hostvars to _meta
@@ -225,18 +225,18 @@ def mem_data_to_dict(inventory):
def dict_to_mem_data(data, inventory=None):
'''
"""
In-place operation on `inventory`, adds contents from `data` to the
in-memory representation of memory.
May be destructive on `data`
'''
"""
assert isinstance(data, dict), 'Expected dict, received {}'.format(type(data))
if inventory is None:
inventory = MemInventory()
_meta = data.pop('_meta', {})
for k,v in data.items():
for k, v in data.items():
group = inventory.get_group(k)
if not group:
continue
@@ -253,9 +253,7 @@ def dict_to_mem_data(data, inventory=None):
if isinstance(hv, dict):
host.variables.update(hv)
else:
logger.warning('Expected dict of vars for '
'host "%s", got %s instead',
hk, str(type(hv)))
logger.warning('Expected dict of vars for ' 'host "%s", got %s instead', hk, str(type(hv)))
group.add_host(host)
elif isinstance(hosts, (list, tuple)):
for hk in hosts:
@@ -264,17 +262,13 @@ def dict_to_mem_data(data, inventory=None):
continue
group.add_host(host)
else:
logger.warning('Expected dict or list of "hosts" for '
'group "%s", got %s instead', k,
str(type(hosts)))
logger.warning('Expected dict or list of "hosts" for ' 'group "%s", got %s instead', k, str(type(hosts)))
# Process group variables.
vars = v.get('vars', {})
if isinstance(vars, dict):
group.variables.update(vars)
else:
logger.warning('Expected dict of vars for '
'group "%s", got %s instead',
k, str(type(vars)))
logger.warning('Expected dict of vars for ' 'group "%s", got %s instead', k, str(type(vars)))
# Process child groups.
children = v.get('children', [])
if isinstance(children, (list, tuple)):
@@ -283,9 +277,7 @@ def dict_to_mem_data(data, inventory=None):
if child and c != 'ungrouped':
group.add_child_group(child)
else:
logger.warning('Expected list of children for '
'group "%s", got %s instead',
k, str(type(children)))
logger.warning('Expected list of children for ' 'group "%s", got %s instead', k, str(type(children)))
# Load host names from a list.
elif isinstance(v, (list, tuple)):
@@ -296,20 +288,17 @@ def dict_to_mem_data(data, inventory=None):
group.add_host(host)
else:
logger.warning('')
logger.warning('Expected dict or list for group "%s", '
'got %s instead', k, str(type(v)))
logger.warning('Expected dict or list for group "%s", ' 'got %s instead', k, str(type(v)))
if k not in ['all', 'ungrouped']:
inventory.all_group.add_child_group(group)
if _meta:
for k,v in inventory.all_group.all_hosts.items():
for k, v in inventory.all_group.all_hosts.items():
meta_hostvars = _meta['hostvars'].get(k, {})
if isinstance(meta_hostvars, dict):
v.variables.update(meta_hostvars)
else:
logger.warning('Expected dict of vars for '
'host "%s", got %s instead',
k, str(type(meta_hostvars)))
logger.warning('Expected dict of vars for ' 'host "%s", got %s instead', k, str(type(meta_hostvars)))
return inventory

View File

@@ -1,6 +1,7 @@
# Python
import urllib.parse
from collections import deque
# Django
from django.db import models
from django.conf import settings
@@ -16,13 +17,10 @@ for c in ';/?:@=&[]':
FK_NAME = 0
NEXT_NODE = 1
NAME_EXCEPTIONS = {
"custom_inventory_scripts": "inventory_scripts"
}
NAME_EXCEPTIONS = {"custom_inventory_scripts": "inventory_scripts"}
class GraphNode(object):
def __init__(self, model, fields, adj_list):
self.model = model
self.found = False
@@ -50,10 +48,7 @@ class GraphNode(object):
current_fk_name = ''
while stack:
if stack[-1].counter == 0:
named_url_component = NAMED_URL_RES_INNER_DILIMITER.join(
["<%s>" % (current_fk_name + field)
for field in stack[-1].fields]
)
named_url_component = NAMED_URL_RES_INNER_DILIMITER.join(["<%s>" % (current_fk_name + field) for field in stack[-1].fields])
named_url_components.append(named_url_component)
if stack[-1].counter >= len(stack[-1].adj_list):
stack[-1].counter = 0
@@ -73,16 +68,15 @@ class GraphNode(object):
return ret
def _encode_uri(self, text):
'''
"""
Performance assured: http://stackoverflow.com/a/27086669
'''
"""
for c in URL_PATH_RESERVED_CHARSET:
if not isinstance(text, str):
text = str(text) # needed for WFJT node creation, identifier temporarily UUID4 type
if c in text:
text = text.replace(c, URL_PATH_RESERVED_CHARSET[c])
text = text.replace(NAMED_URL_RES_INNER_DILIMITER,
'[%s]' % NAMED_URL_RES_INNER_DILIMITER)
text = text.replace(NAMED_URL_RES_INNER_DILIMITER, '[%s]' % NAMED_URL_RES_INNER_DILIMITER)
return text
def generate_named_url(self, obj):
@@ -91,8 +85,7 @@ class GraphNode(object):
stack = [self]
while stack:
if stack[-1].counter == 0:
named_url_item = [self._encode_uri(getattr(stack[-1].obj, field, ''))
for field in stack[-1].fields]
named_url_item = [self._encode_uri(getattr(stack[-1].obj, field, '')) for field in stack[-1].fields]
named_url.append(NAMED_URL_RES_INNER_DILIMITER.join(named_url_item))
if stack[-1].counter >= len(stack[-1].adj_list):
stack[-1].counter = 0
@@ -109,7 +102,6 @@ class GraphNode(object):
named_url.append('')
return NAMED_URL_RES_DILIMITER.join(named_url)
def _process_top_node(self, named_url_names, kwargs, prefixes, stack, idx):
if stack[-1].counter == 0:
if idx >= len(named_url_names):
@@ -146,16 +138,13 @@ class GraphNode(object):
def populate_named_url_query_kwargs(self, kwargs, named_url, ignore_digits=True):
if ignore_digits and named_url.isdigit() and int(named_url) > 0:
return False
named_url = named_url.replace('[%s]' % NAMED_URL_RES_INNER_DILIMITER,
NAMED_URL_RES_DILIMITER_ENCODE)
named_url = named_url.replace('[%s]' % NAMED_URL_RES_INNER_DILIMITER, NAMED_URL_RES_DILIMITER_ENCODE)
named_url_names = named_url.split(NAMED_URL_RES_DILIMITER)
prefixes = []
stack = [self]
idx = 0
while stack:
idx, is_valid = self._process_top_node(
named_url_names, kwargs, prefixes, stack, idx
)
idx, is_valid = self._process_top_node(named_url_names, kwargs, prefixes, stack, idx)
if not is_valid:
return False
return idx == len(named_url_names)
@@ -192,10 +181,12 @@ def _get_all_unique_togethers(model):
soft_uts = getattr(model_to_backtrack, 'SOFT_UNIQUE_TOGETHER', [])
ret.extend(soft_uts)
for parent_class in model_to_backtrack.__bases__:
if issubclass(parent_class, models.Model) and\
hasattr(parent_class, '_meta') and\
hasattr(parent_class._meta, 'unique_together') and\
isinstance(parent_class._meta.unique_together, tuple):
if (
issubclass(parent_class, models.Model)
and hasattr(parent_class, '_meta')
and hasattr(parent_class._meta, 'unique_together')
and isinstance(parent_class._meta.unique_together, tuple)
):
queue.append(parent_class)
ret.sort(key=lambda x: len(x))
return tuple(ret)
@@ -261,18 +252,11 @@ def _dfs(configuration, model, graph, dead_ends, new_deadends, parents):
next_model = model._meta.get_field(fk_name).related_model
if issubclass(next_model, ContentType):
continue
if next_model not in configuration or\
next_model in dead_ends or\
next_model in new_deadends or\
next_model in parents:
if next_model not in configuration or next_model in dead_ends or next_model in new_deadends or next_model in parents:
new_deadends.add(model)
parents.remove(model)
return False
if next_model not in graph and\
not _dfs(
configuration, next_model, graph,
dead_ends, new_deadends, parents
):
if next_model not in graph and not _dfs(configuration, next_model, graph, dead_ends, new_deadends, parents):
new_deadends.add(model)
parents.remove(model)
return False

View File

@@ -1,4 +1,3 @@
from django.contrib.contenttypes.models import ContentType
from django.db import models

View File

@@ -64,16 +64,18 @@ def timing(name, *init_args, **init_kwargs):
res = func(*args, **kwargs)
timing.stop()
return res
return wrapper_profile
return decorator_profile
class AWXProfiler(AWXProfileBase):
def __init__(self, name, dest='/var/log/tower/profile', dot_enabled=True):
'''
"""
Try to do as little as possible in init. Instead, do the init
only when the profiling is started.
'''
"""
super().__init__(name, dest)
self.started = False
self.dot_enabled = dot_enabled
@@ -101,11 +103,7 @@ class AWXProfiler(AWXProfileBase):
dot_filepath = os.path.join(self.dest, f"{filename_base}.dot")
pstats.Stats(self.prof).dump_stats(raw_filepath)
generate_dot([
'-n', '2.5', '-f', 'pstats', '-o',
dot_filepath,
raw_filepath
])
generate_dot(['-n', '2.5', '-f', 'pstats', '-o', dot_filepath, raw_filepath])
os.remove(raw_filepath)
with open(pstats_filepath, 'w') as f:
@@ -113,7 +111,6 @@ class AWXProfiler(AWXProfileBase):
pstats.Stats(self.prof, stream=f).sort_stats('cumulative').print_stats()
return pstats_filepath
def start(self):
self.prof = cProfile.Profile()
self.pid = os.getpid()
@@ -146,6 +143,7 @@ def profile(name, *init_args, **init_kwargs):
res = func(*args, **kwargs)
prof.stop()
return res
return wrapper_profile
return decorator_profile
return wrapper_profile
return decorator_profile

View File

@@ -11,10 +11,10 @@ logger = logging.getLogger('awx.main.utils.reload')
def supervisor_service_command(command, service='*', communicate=True):
'''
"""
example use pattern of supervisorctl:
# supervisorctl restart tower-processes:receiver tower-processes:factcacher
'''
"""
args = ['supervisorctl']
supervisor_config_path = os.getenv('SUPERVISOR_WEB_CONFIG_PATH', None)
@@ -23,18 +23,18 @@ def supervisor_service_command(command, service='*', communicate=True):
args.extend([command, ':'.join(['tower-processes', service])])
logger.debug('Issuing command to {} services, args={}'.format(command, args))
supervisor_process = subprocess.Popen(args, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
supervisor_process = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if communicate:
restart_stdout, restart_err = supervisor_process.communicate()
restart_code = supervisor_process.returncode
if restart_code or restart_err:
logger.error('supervisorctl {} {} errored with exit code `{}`, stdout:\n{}stderr:\n{}'.format(
command, service, restart_code, restart_stdout.strip(), restart_err.strip()))
else:
logger.debug(
'supervisorctl {} {} succeeded'.format(command, service)
logger.error(
'supervisorctl {} {} errored with exit code `{}`, stdout:\n{}stderr:\n{}'.format(
command, service, restart_code, restart_stdout.strip(), restart_err.strip()
)
)
else:
logger.debug('supervisorctl {} {} succeeded'.format(command, service))
else:
logger.info('Submitted supervisorctl {} command, not waiting for result'.format(command))

View File

@@ -6,7 +6,6 @@ __all__ = ['safe_dump', 'SafeLoader']
class SafeStringDumper(yaml.SafeDumper):
def represent_data(self, value):
if isinstance(value, str):
return self.represent_scalar('!unsafe', value)
@@ -14,18 +13,15 @@ class SafeStringDumper(yaml.SafeDumper):
class SafeLoader(yaml.Loader):
def construct_yaml_unsafe(self, node):
class UnsafeText(str):
__UNSAFE__ = True
node = UnsafeText(self.construct_scalar(node))
return node
SafeLoader.add_constructor(
u'!unsafe',
SafeLoader.construct_yaml_unsafe
)
SafeLoader.add_constructor(u'!unsafe', SafeLoader.construct_yaml_unsafe)
def safe_dump(x, safe_dict=None):
@@ -41,7 +37,7 @@ def safe_dump(x, safe_dict=None):
resulting YAML. Anything _not_ in this dict will automatically be
`!unsafe`.
safe_dump({'a': 'b', 'c': 'd'}) ->
safe_dump({'a': 'b', 'c': 'd'}) ->
!unsafe 'a': !unsafe 'b'
!unsafe 'c': !unsafe 'd'
@@ -59,12 +55,14 @@ def safe_dump(x, safe_dict=None):
dumper = yaml.SafeDumper
if k not in safe_dict or safe_dict.get(k) != v:
dumper = SafeStringDumper
yamls.append(yaml.dump_all(
[{k: v}],
None,
Dumper=dumper,
default_flow_style=False,
))
yamls.append(
yaml.dump_all(
[{k: v}],
None,
Dumper=dumper,
default_flow_style=False,
)
)
return ''.join(yamls)
else:
return yaml.dump_all([x], None, Dumper=SafeStringDumper, default_flow_style=False)