Skip to content
Merged
19 changes: 11 additions & 8 deletions BlockServerToKafka/block_server_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
# http://opensource.org/licenses/eclipse-1.0.php
import json
from threading import RLock
from typing import Any

import ca
from CaChannel import CaChannel, CaChannelException

from BlockServer.core.macros import BLOCK_PREFIX
from BlockServerToKafka.kafka_producer import ProducerWrapper
from server_common.utilities import dehex_and_decompress, print_and_log


Expand All @@ -30,7 +32,7 @@ class BlockServerMonitor:
Uses a Channel Access Monitor.
"""

def __init__(self, address, pvprefix, producer):
def __init__(self, address: str, pvprefix: str, producer: ProducerWrapper) -> None:
self.PVPREFIX = pvprefix
self.address = address
self.channel = CaChannel()
Expand All @@ -45,15 +47,15 @@ def __init__(self, address, pvprefix, producer):

# Create the CA monitor callback
self.channel.add_masked_array_event(
ca.dbf_type_to_DBR_STS(self.channel.field_type()),
ca.dbf_type_to_DBR_STS(self.channel.field_type()), # pyright: ignore
0,
ca.DBE_VALUE,
ca.DBE_VALUE, # pyright: ignore
self.update,
None,
)
self.channel.pend_event()

def block_name_to_pv_name(self, blk):
def block_name_to_pv_name(self, blk: str) -> str:
"""
Converts a block name to a PV name by adding the prefixes.

Expand All @@ -66,11 +68,12 @@ def block_name_to_pv_name(self, blk):
return f"{self.PVPREFIX}{BLOCK_PREFIX}{blk}"

@staticmethod
def convert_to_string(pv_array):
def convert_to_string(pv_array: bytearray) -> str:
"""
Convert from byte array to string and remove null characters.

We cannot get the number of elements in the array so convert to bytes and remove the null characters.
We cannot get the number of elements in the array so convert to bytes and remove the
null characters.

Args:
pv_array (bytearray): The byte array of PVs.
Expand All @@ -81,7 +84,7 @@ def convert_to_string(pv_array):

return bytearray(pv_array).decode("utf-8").replace("\x00", "")

def update_config(self, blocks):
def update_config(self, blocks: list[str]) -> None:
"""
Updates the forwarder configuration to monitor the supplied blocks.

Expand All @@ -99,7 +102,7 @@ def update_config(self, blocks):
self.producer.add_config(pvs)
self.last_pvs = pvs

def update(self, epics_args, user_args):
def update(self, epics_args: dict[str, bytearray], user_args: Any) -> None: # noqa: ANN401
"""
Updates the kafka config when the blockserver changes. This is called from the monitor.

Expand Down
21 changes: 13 additions & 8 deletions BlockServerToKafka/forwarder_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,39 @@
# http://opensource.org/licenses/eclipse-1.0.php
from typing import List

from streaming_data_types.fbschemas.forwarder_config_update_rf5k.Protocol import (
from streaming_data_types.fbschemas.forwarder_config_update_fc00.Protocol import (
Protocol,
)
from streaming_data_types.fbschemas.forwarder_config_update_rf5k.UpdateType import (
from streaming_data_types.fbschemas.forwarder_config_update_fc00.UpdateType import (
UpdateType,
)
from streaming_data_types.forwarder_config_update_rf5k import StreamInfo, serialise_rf5k
from streaming_data_types.forwarder_config_update_fc00 import StreamInfo, serialise_fc00


class ForwarderConfig:
"""
Class that converts the pv information to a forwarder config message payload
"""

def __init__(self, topic: str, epics_protocol: Protocol = Protocol.CA, schema: str = "f142"):
def __init__(
self,
topic: str,
epics_protocol: Protocol = Protocol.CA, # pyright: ignore
schema: str = "f144",
) -> None:
self.schema = schema
self.topic = topic
self.epics_protocol = epics_protocol

def _create_streams(self, pvs: List[str]) -> List[StreamInfo]:
return [StreamInfo(pv, self.schema, self.topic, self.epics_protocol) for pv in pvs]
return [StreamInfo(pv, self.schema, self.topic, self.epics_protocol, 0) for pv in pvs] # pyright: ignore

def create_forwarder_configuration(self, pvs: List[str]) -> bytes:
return serialise_rf5k(UpdateType.ADD, self._create_streams(pvs))
return serialise_fc00(UpdateType.ADD, self._create_streams(pvs)) # pyright: ignore

def remove_forwarder_configuration(self, pvs: List[str]) -> bytes:
return serialise_rf5k(UpdateType.REMOVE, self._create_streams(pvs))
return serialise_fc00(UpdateType.REMOVE, self._create_streams(pvs)) # pyright: ignore

@staticmethod
def remove_all_forwarder_configuration() -> bytes:
return serialise_rf5k(UpdateType.REMOVEALL, [])
return serialise_fc00(UpdateType.REMOVEALL, []) # pyright: ignore
46 changes: 27 additions & 19 deletions BlockServerToKafka/kafka_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from typing import List

from kafka import KafkaConsumer, KafkaProducer, errors
from streaming_data_types.fbschemas.forwarder_config_update_rf5k.Protocol import (
from streaming_data_types.fbschemas.forwarder_config_update_fc00.Protocol import (
Protocol,
)

from BlockServerToKafka.forwarder_config import ForwarderConfig
from server_common.utilities import print_and_log
from server_common.utilities import SEVERITY, print_and_log


class ProducerWrapper:
Expand All @@ -35,39 +35,47 @@ def __init__(
server: str,
config_topic: str,
data_topic: str,
epics_protocol: Protocol = Protocol.CA,
):
epics_protocol: Protocol = Protocol.CA, # pyright: ignore
) -> None:
self.topic = config_topic
self.converter = ForwarderConfig(data_topic, epics_protocol)
self._set_up_producer(server)
while not self._set_up_producer(server):
print_and_log("Failed to create producer, retrying in 30s")
sleep(30)

def _set_up_producer(self, server: str):
def _set_up_producer(self, server: str) -> bool:
"""
Attempts to create a Kafka producer and consumer. Retries with a recursive call every 30s.
"""
try:
self.client = KafkaConsumer(bootstrap_servers=server)
self.producer = KafkaProducer(bootstrap_servers=server)
if not self.topic_exists(self.topic):
print_and_log(
f"WARNING: topic {self.topic} does not exist. It will be created by default."
)
return True
except errors.NoBrokersAvailable:
print_and_log(f"No brokers found on server: {server[0]}")
except errors.ConnectionError:
print_and_log("No server found, connection error")
print_and_log(f"No brokers found on server: {server[0]}", severity=SEVERITY.MAJOR)
except errors.KafkaConnectionError:
print_and_log("No server found, connection error", severity=SEVERITY.MAJOR)
except errors.InvalidConfigurationError:
print_and_log("Invalid configuration")
print_and_log("Invalid configuration", severity=SEVERITY.MAJOR)
quit()
except errors.InvalidTopicError:
print_and_log(
"Invalid topic, to enable auto creation of topics set"
" auto.create.topics.enable to false in broker configuration"
" auto.create.topics.enable to false in broker configuration",
severity=SEVERITY.MAJOR,
)
except Exception as e:
print_and_log(
f"Unexpected error while creating producer or consumer: {str(e)}",
severity=SEVERITY.MAJOR,
)
finally:
print_and_log("Retrying in 10s")
sleep(10)
# Recursive call after waiting
self._set_up_producer(server)
return False

def add_config(self, pvs: List[str]):
def add_config(self, pvs: List[str]) -> None:
"""
Create a forwarder configuration to add more pvs to be monitored.

Expand All @@ -79,7 +87,7 @@ def add_config(self, pvs: List[str]):
def topic_exists(self, topic_name: str) -> bool:
return topic_name in self.client.topics()

def remove_config(self, pvs: List[str]):
def remove_config(self, pvs: List[str]) -> None:
"""
Create a forwarder configuration to remove pvs that are being monitored.

Expand All @@ -88,7 +96,7 @@ def remove_config(self, pvs: List[str]):
message_buffer = self.converter.remove_forwarder_configuration(pvs)
self.producer.send(self.topic, message_buffer)

def stop_all_pvs(self):
def stop_all_pvs(self) -> None:
"""
Sends a stop_all command to the forwarder to clear all configuration.
"""
Expand Down
28 changes: 0 additions & 28 deletions BlockServerToKafka/test_modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +0,0 @@
from __future__ import absolute_import, division, print_function, unicode_literals

# This file is part of the ISIS IBEX application.
# Copyright (C) 2012-2020 Science & Technology Facilities Council.
# All rights reserved.
#
# This program is distributed in the hope that it will be useful.
# This program and the accompanying materials are made available under the
# terms of the Eclipse Public License v1.0 which accompanies this distribution.
# EXCEPT AS EXPRESSLY SET FORTH IN THE ECLIPSE PUBLIC LICENSE V1.0, THE PROGRAM
# AND ACCOMPANYING MATERIALS ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES
# OR CONDITIONS OF ANY KIND. See the Eclipse Public License v1.0 for more details.
#
# You should have received a copy of the Eclipse Public License v1.0
# along with this program; if not, you can obtain a copy from
# https://www.eclipse.org/org/documents/epl-v10.php or
# http://opensource.org/licenses/eclipse-1.0.php
import os


def load_tests(loader, standard_tests, pattern):
"""
This function is needed by the load_tests protocol described at
https://docs.python.org/3/library/unittest.html#load-tests-protocol
The tests in this module are only added under Python 3.
"""
standard_tests.addTests(loader.discover(os.path.dirname(__file__), pattern=pattern))
return standard_tests
54 changes: 32 additions & 22 deletions BlockServerToKafka/test_modules/test_block_server_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,61 +21,71 @@
from BlockServerToKafka.block_server_monitor import BlockServerMonitor


@patch("CaChannel.CaChannel")
class TestBlockServerMonitor(unittest.TestCase):
test_address = "TEST_ADDRESS"
test_prefix = "TEST_PREFIX"

@patch("CaChannel.CaChannel")
@patch("CaChannel.CaChannel.searchw")
@patch("CaChannel.CaChannel.add_masked_array_event")
@patch("CaChannel.CaChannel.field_type")
@patch("CaChannel.CaChannel.pend_event")
def setUp(self, mock_ca_channel, mock_search, mock_add_array, mock_field_type, mock_pend_event):
def setUp(self):
self.mock_producer = MagicMock()
self.bs_monitor = BlockServerMonitor(
self.test_address, self.test_prefix, self.mock_producer
)

def test_WHEN_convert_one_char_to_string_THEN_returns_character(self):
def test_WHEN_convert_one_char_to_string_THEN_returns_character(
self,
mock_ca_channel,
):
c = "a"
arr = [ord(c)]
self.assertEqual(c, self.bs_monitor.convert_to_string(arr))
self.assertEqual(c, self.bs_monitor.convert_to_string(bytearray(arr)))

def test_WHEN_convert_many_chars_to_string_THEN_returns_characters(self):
def test_WHEN_convert_many_chars_to_string_THEN_returns_characters(self, mock_ca_channel):
chars = "hello world"
arr = [ord(c) for c in chars]
self.assertEqual(chars, self.bs_monitor.convert_to_string(arr))
self.assertEqual(chars, self.bs_monitor.convert_to_string(bytearray(arr)))

def test_WHEN_convert_chars_with_null_at_end_THEN_nulls_removed(self):
def test_WHEN_convert_chars_with_null_at_end_THEN_nulls_removed(
self,
mock_ca_channel,
):
chars = "hello world"
arr = [ord(c) for c in chars]
for i in range(3):
arr.append(0)
self.assertEqual(chars, self.bs_monitor.convert_to_string(arr))
self.assertEqual(chars, self.bs_monitor.convert_to_string(bytearray(arr)))

def test_WHEN_convert_chars_with_null_at_start_THEN_nulls_removed(self):
def test_WHEN_convert_chars_with_null_at_start_THEN_nulls_removed(
self,
mock_ca_channel,
):
chars = "hello world"
arr = [ord(c) for c in chars]
for i in range(3):
arr.insert(0, 0)
self.assertEqual(chars, self.bs_monitor.convert_to_string(arr))
self.assertEqual(chars, self.bs_monitor.convert_to_string(bytearray(arr)))

def test_WHEN_convert_chars_with_nulls_in_centre_THEN_nulls_removed(self):
def test_WHEN_convert_chars_with_nulls_in_centre_THEN_nulls_removed(self, mock_ca_channel):
chars = "hello world"
arr = [ord(c) for c in chars]
arr.insert(4, 0)
self.assertEqual(chars, self.bs_monitor.convert_to_string(arr))
self.assertEqual(chars, self.bs_monitor.convert_to_string(bytearray(arr)))

def test_WHEN_convert_nulls_THEN_empty_string_returned(self):
def test_WHEN_convert_nulls_THEN_empty_string_returned(
self,
mock_ca_channel,
):
arr = [0] * 10
self.assertEqual("", self.bs_monitor.convert_to_string(arr))
self.assertEqual("", self.bs_monitor.convert_to_string(bytearray(arr)))

def test_GIVEN_no_previous_pvs_WHEN_update_config_called_THEN_producer_is_called(self):
def test_GIVEN_no_previous_pvs_WHEN_update_config_called_THEN_producer_is_called(
self, mock_ca_channel
):
self.bs_monitor.update_config(["BLOCK"])
self.mock_producer.add_config.assert_called_once()

def test_GIVEN_no_previous_pvs_WHEN_update_config_called_THEN_producer_is_called_containing_block_name(
self,
self, mock_ca_channel
):
block = "BLOCK"
self.bs_monitor.update_config([block])
Expand All @@ -84,15 +94,15 @@ def test_GIVEN_no_previous_pvs_WHEN_update_config_called_THEN_producer_is_called
)

def test_GIVEN_previous_pvs_WHEN_update_config_called_with_same_pvs_THEN_producer_is_not_called(
self,
self, mock_ca_channel
):
block = "BLOCK"
self.bs_monitor.update_config([block])
self.bs_monitor.update_config([block])
self.mock_producer.add_config.assert_called_once()

def test_GIVEN_previous_pvs_WHEN_update_config_called_with_different_pvs_THEN_producer_is_called(
self,
self, mock_ca_channel
):
self.bs_monitor.update_config(["OLD_BLOCK"])
self.mock_producer.reset_mock()
Expand Down
Loading
Loading