Files
awx/awx/main/wsbroadcast.py
Rick Elrod ef29589940 Fix duped stats name and Redis for wsbroadcast
This fixes several things related to our wsbroadcast stats handling.
This was found during the ongoing wsrelay work.

There are really three fixes here:

- Logging was not actually enabled for the analytics.broadcast_websocket
  module, so that has been added to our loggers config.

- analytics.broadcast_websocket was not actually able to connect to
  Redis due to 68614b83c0 as part of
  the work in #13187. But there was no easy way to know this because the
  logging issue meant no exceptions showed up anywhere reasonable.

- Relatedly, and also as part of #13187, we jumped from
  `prometheus-client` 0.7.1 up to 0.15.0. This included a breaking
  change where a `Counter` ending with `_total` will clash with a
  `Gauge` of the same name but without `_total`. I am not 100% sure of
  the reasoning here, other than "OpenMetrics compatibility".

Refs #13301
Refs #13187

Signed-off-by: Rick Elrod <rick@elrod.me>
2022-12-08 12:54:08 -06:00

210 lines
8.1 KiB
Python

import json
import logging
import asyncio
import aiohttp
from aiohttp import client_exceptions
from asgiref.sync import sync_to_async
from channels.layers import get_channel_layer
from django.conf import settings
from django.apps import apps
from django.core.serializers.json import DjangoJSONEncoder
from awx.main.analytics.broadcast_websocket import (
BroadcastWebsocketStats,
BroadcastWebsocketStatsManager,
)
import awx.main.analytics.subsystem_metrics as s_metrics
logger = logging.getLogger('awx.main.wsbroadcast')
def wrap_broadcast_msg(group, message: str):
    """Serialize a (group, message) pair into the JSON envelope sent between nodes.

    The envelope is a JSON object with the keys ``group`` and ``message``,
    encoded with ``DjangoJSONEncoder``.
    """
    # TODO: Maybe wrap as "group","message" so that we don't need to
    # encode/decode as json.
    envelope = {'group': group, 'message': message}
    return json.dumps(envelope, cls=DjangoJSONEncoder)
def unwrap_broadcast_msg(payload: dict):
    """Split a decoded broadcast envelope into its ``(group, message)`` tuple.

    Raises ``KeyError`` if either the ``group`` or ``message`` key is absent.
    """
    group = payload['group']
    message = payload['message']
    return (group, message)
@sync_to_async
def get_broadcast_hosts():
    """Return ``{hostname: address}`` for every remote instance to broadcast to.

    The local instance and instances whose ``node_type`` is ``execution`` or
    ``hop`` are excluded.  The address is the instance's ``ip_address`` when
    that is set (truthy), otherwise its ``hostname``.

    Wrapped in ``sync_to_async`` so the ORM query can be awaited from
    coroutine code.
    """
    Instance = apps.get_model('main', 'Instance')
    local_hostname = Instance.objects.my_hostname()

    peers = Instance.objects.exclude(hostname=local_hostname)
    for skipped_node_type in ('execution', 'hop'):
        peers = peers.exclude(node_type=skipped_node_type)
    rows = peers.order_by('hostname').values('hostname', 'ip_address').distinct()

    hosts = {}
    for row in rows:
        hosts[row['hostname']] = row['ip_address'] or row['hostname']
    return hosts
def get_local_host():
    """Return the local hostname as reported by ``Instance.objects.my_hostname()``."""
    instance_model = apps.get_model('main', 'Instance')
    manager = instance_model.objects
    return manager.my_hostname()
class WebsocketTask:
    """Base class for a self-reconnecting websocket client to one remote host.

    ``start()`` schedules ``connect()`` on the supplied event loop.
    ``connect()`` opens the websocket, hands it to ``run_loop()`` (which
    subclasses must implement) and, unless it was cancelled, schedules the
    next connection attempt once the connection ends or fails.
    """

    def __init__(
        self,
        name,
        event_loop,
        stats: BroadcastWebsocketStats,
        remote_host: str,
        remote_port: int = settings.BROADCAST_WEBSOCKET_PORT,
        protocol: str = settings.BROADCAST_WEBSOCKET_PROTOCOL,
        verify_ssl: bool = settings.BROADCAST_WEBSOCKET_VERIFY_CERT,
        endpoint: str = 'broadcast',
    ):
        """Store connection parameters; no network activity happens here.

        Args:
            name: Label for this end of the connection; used in log messages
                and as the ``instance_name`` of the subsystem metrics object.
            event_loop: Event loop on which connection tasks are created.
            stats: Per-remote-host stats object that connection and message
                events are recorded on.
            remote_host: Host (name or address) to connect to.
            remote_port: Port of the remote websocket endpoint.
            protocol: URI scheme used when building the websocket URI.
            verify_ssl: Passed as the ``ssl`` argument of ``ws_connect``.
            endpoint: Last path component of ``/websocket/<endpoint>/``.
        """
        self.name = name
        self.event_loop = event_loop
        self.stats = stats
        self.remote_host = remote_host
        self.remote_port = remote_port
        self.endpoint = endpoint
        self.protocol = protocol
        self.verify_ssl = verify_ssl
        # Resolved lazily in connect(); see the comment there.
        self.channel_layer = None
        # Task handle for the currently scheduled connect(); set by start().
        # Initialised here so that cancel() is safe before the first start().
        self.async_task = None
        self.subsystem_metrics = s_metrics.Metrics(instance_name=name)

    async def run_loop(self, websocket: aiohttp.ClientWebSocketResponse):
        """Process an established websocket; subclasses must override this."""
        raise RuntimeError("Implement me")

    async def connect(self, attempt):
        """Open the websocket, run ``run_loop()`` on it, then schedule a retry.

        Args:
            attempt: Number of this connection attempt.  Any value greater
                than 0 first sleeps for
                ``settings.BROADCAST_WEBSOCKET_RECONNECT_RETRY_RATE_SECONDS``.

        Raises:
            asyncio.CancelledError: Re-raised when the task is cancelled, in
                which case no further attempt is scheduled.
        """
        from awx.main.consumers import WebsocketSecretAuthHelper  # noqa

        logger.debug(f"Connection from {self.name} to {self.remote_host} attempt number {attempt}.")

        # Can not put get_channel_layer() in the init code because it is in the init
        # path of channel layers i.e. RedisChannelLayer() calls our init code.
        if not self.channel_layer:
            self.channel_layer = get_channel_layer()

        try:
            if attempt > 0:
                await asyncio.sleep(settings.BROADCAST_WEBSOCKET_RECONNECT_RETRY_RATE_SECONDS)
        except asyncio.CancelledError:
            logger.warning(f"Connection from {self.name} to {self.remote_host} cancelled")
            raise

        uri = f"{self.protocol}://{self.remote_host}:{self.remote_port}/websocket/{self.endpoint}/"
        timeout = aiohttp.ClientTimeout(total=10)

        secret_val = WebsocketSecretAuthHelper.construct_secret()
        try:
            async with aiohttp.ClientSession(headers={'secret': secret_val}, timeout=timeout) as session:
                async with session.ws_connect(uri, ssl=self.verify_ssl, heartbeat=20) as websocket:
                    logger.info(f"Connection from {self.name} to {self.remote_host} established.")
                    self.stats.record_connection_established()
                    # A successful connection resets the retry counter, so the
                    # reconnect scheduled below starts again from attempt 1.
                    attempt = 0
                    await self.run_loop(websocket)
        except asyncio.CancelledError:
            # TODO: Check if connected and disconnect
            # Possibly use run_until_complete() if disconnect is async
            logger.warning(f"Connection from {self.name} to {self.remote_host} cancelled.")
            self.stats.record_connection_lost()
            raise
        except client_exceptions.ClientConnectorError as e:
            logger.warning(f"Connection from {self.name} to {self.remote_host} failed: '{e}'.")
        except asyncio.TimeoutError:
            logger.warning(f"Connection from {self.name} to {self.remote_host} timed out.")
        except Exception as e:
            # Early on, this is our canary. I'm not sure what exceptions we can really encounter.
            logger.exception(f"Connection from {self.name} to {self.remote_host} failed for unknown reason: '{e}'.")
        else:
            # run_loop() returned without raising: the connection simply ended.
            logger.warning(f"Connection from {self.name} to {self.remote_host} lost.")

        # Reached on every path except cancellation: record the loss and retry.
        self.stats.record_connection_lost()
        self.start(attempt=attempt + 1)

    def start(self, attempt=0):
        """Schedule ``connect(attempt)`` on the event loop and keep its task handle."""
        self.async_task = self.event_loop.create_task(self.connect(attempt=attempt))

    def cancel(self):
        """Cancel the currently scheduled connection task, if one was started."""
        if self.async_task is not None:
            self.async_task.cancel()
class BroadcastWebsocketTask(WebsocketTask):
    """Websocket task that relays broadcast messages from a remote host."""

    async def run_loop(self, websocket: aiohttp.ClientWebSocketResponse):
        """Consume messages from ``websocket`` until it closes or errors.

        Every received message is counted on ``self.stats``.  TEXT messages
        are decoded as JSON envelopes: those for the ``metrics`` group are
        stored via ``self.subsystem_metrics``; all others are forwarded to the
        channel layer group named in the envelope.  An ERROR message ends the
        loop; other message types are ignored.
        """
        async for msg in websocket:
            self.stats.record_message_received()

            if msg.type == aiohttp.WSMsgType.ERROR:
                break
            elif msg.type == aiohttp.WSMsgType.TEXT:
                try:
                    payload = json.loads(msg.data)
                except json.JSONDecodeError:
                    logmsg = "Failed to decode broadcast message"
                    if logger.isEnabledFor(logging.DEBUG):
                        # Log the raw data: `payload` is not bound when
                        # json.loads() fails (or is stale from an earlier
                        # message), so it must not be referenced here.
                        logmsg = "{} {}".format(logmsg, msg.data)
                    logger.warning(logmsg)
                    continue
                (group, message) = unwrap_broadcast_msg(payload)
                if group == "metrics":
                    # Metrics are stored locally rather than forwarded to a group.
                    self.subsystem_metrics.store_metrics(message)
                    continue
                await self.channel_layer.group_send(group, {"type": "internal.message", "text": message})
class BroadcastWebsocketManager(object):
    """Keeps one ``BroadcastWebsocketTask`` running per remote broadcast host."""

    def __init__(self):
        """Capture the event loop and local hostname and create the stats manager."""
        self.event_loop = asyncio.get_event_loop()
        # Maps hostname -> running task, e.g.
        #   {
        #       'hostname1': BroadcastWebsocketTask(),
        #       'hostname2': BroadcastWebsocketTask(),
        #       'hostname3': BroadcastWebsocketTask(),
        #   }
        self.broadcast_tasks = {}
        self.local_hostname = get_local_host()
        self.stats_mgr = BroadcastWebsocketStatsManager(self.event_loop, self.local_hostname)

    async def run_per_host_websocket(self):
        """Poll the broadcast host list forever and reconcile the running tasks.

        On each pass, tasks for hosts that disappeared are cancelled, tasks
        for newly seen hosts are started, and a host whose address changed is
        torn down and recreated.  The loop then sleeps for
        ``settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS``.
        """
        while True:
            known_hosts = await get_broadcast_hosts()

            wanted = set(known_hosts.keys())
            running = set(self.broadcast_tasks.keys())
            to_remove = running - wanted
            to_add = wanted - running

            # A host that is still known but now resolves to a different
            # address is scheduled for both removal and re-creation.
            for hostname, address in known_hosts.items():
                existing = self.broadcast_tasks.get(hostname)
                if existing is not None and address != existing.remote_host:
                    to_remove.add(hostname)
                    to_add.add(hostname)

            if to_remove:
                logger.warning(f"Removing {to_remove} from websocket broadcast list")
            if to_add:
                logger.warning(f"Adding {to_add} to websocket broadcast list")

            for hostname in to_remove:
                stale_task = self.broadcast_tasks[hostname]
                stale_task.cancel()
                del self.broadcast_tasks[hostname]
                self.stats_mgr.delete_remote_host_stats(hostname)

            for hostname in to_add:
                host_stats = self.stats_mgr.new_remote_host_stats(hostname)
                new_task = BroadcastWebsocketTask(
                    name=self.local_hostname,
                    event_loop=self.event_loop,
                    stats=host_stats,
                    remote_host=known_hosts[hostname],
                )
                new_task.start()
                self.broadcast_tasks[hostname] = new_task

            await asyncio.sleep(settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS)

    def start(self):
        """Start the stats manager and schedule the polling loop; return its task."""
        self.stats_mgr.start()
        polling_task = self.event_loop.create_task(self.run_per_host_websocket())
        self.async_task = polling_task
        return polling_task