Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 31 additions & 21 deletions ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def __init__(self):
self._registered_topics: List[appcallback_v1.TopicSubscription] = []
self._registered_bindings: List[str] = []

self._route_map: Dict[Tuple[str, str], TopicSubscribeCallable] = {}
self._validation_disabled_pubsubs: Dict[str, TopicSubscribeCallable] = {}

def _get_topic_callback(
self, pubsub_name: str, topic: str, path: str
) -> Optional[TopicSubscribeCallable]:
pubsub_topic = pubsub_name + DELIMITER + topic + DELIMITER + path
if pubsub_topic in self._topic_map:
return self._topic_map[pubsub_topic]

if (pubsub_name, path) in self._route_map:
return self._route_map[(pubsub_name, path)]

if path == '':
if pubsub_name in self._validation_disabled_pubsubs:
return self._validation_disabled_pubsubs[pubsub_name]

return None

def register_method(self, method: str, cb: InvokeMethodCallable) -> None:
"""Registers method for service invocation."""
if method in self._invoke_method_map:
Expand All @@ -98,17 +117,18 @@ def register_topic(
disable_topic_validation: Optional[bool] = False,
) -> None:
"""Registers topic subscription for pubsub."""
if not disable_topic_validation:
topic_key = pubsub_name + DELIMITER + topic
else:
topic_key = pubsub_name
topic_key = pubsub_name + DELIMITER + topic
pubsub_topic = topic_key + DELIMITER
if rule is not None:
path = getattr(cb, '__name__', rule.match)
pubsub_topic = pubsub_topic + path
if pubsub_topic in self._topic_map:
raise ValueError(f'{topic} is already registered with {pubsub_name}')
self._topic_map[pubsub_topic] = cb
self._route_map[(pubsub_name, topic)] = cb

if disable_topic_validation:
self._validation_disabled_pubsubs[pubsub_name] = cb

registered_topic = self._registered_topics_map.get(topic_key)
sub: appcallback_v1.TopicSubscription = appcallback_v1.TopicSubscription()
Expand Down Expand Up @@ -196,15 +216,10 @@ def ListTopicSubscriptions(self, request, context):

def OnTopicEvent(self, request: TopicEventRequest, context):
"""Subscribes events from Pubsub."""
pubsub_topic = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path
no_validation_key = request.pubsub_name + DELIMITER + request.path

if pubsub_topic not in self._topic_map:
if no_validation_key in self._topic_map:
pubsub_topic = no_validation_key
else:
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
raise NotImplementedError(f'topic {request.topic} is not implemented!')
cb = self._get_topic_callback(request.pubsub_name, request.topic, request.path)
if cb is None:
context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore
raise NotImplementedError(f'topic {request.topic} is not implemented!')

customdata: Struct = request.extensions
extensions = dict()
Expand All @@ -222,7 +237,7 @@ def OnTopicEvent(self, request: TopicEventRequest, context):
event.SetSubject(request.topic)
event.SetExtensions(extensions)

response = self._topic_map[pubsub_topic](event)
response = cb(event)
if isinstance(response, TopicEventResponse):
return appcallback_v1.TopicEventResponse(status=response.status.value)
return empty_pb2.Empty()
Expand Down Expand Up @@ -292,15 +307,10 @@ def _handle_bulk_topic_event(
self, request: TopicEventBulkRequest, context
) -> Optional[TopicEventBulkResponse]:
"""Process bulk topic event request - routes each entry to the appropriate topic handler."""
topic_key = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path
no_validation_key = request.pubsub_name + DELIMITER + request.path

if topic_key not in self._topic_map and no_validation_key not in self._topic_map:
cb = self._get_topic_callback(request.pubsub_name, request.topic, request.path)
if cb is None:
return None # we don't have a handler

handler_key = topic_key if topic_key in self._topic_map else no_validation_key
cb = self._topic_map[handler_key] # callback

statuses = []
for entry in request.entries:
entry_id = entry.entry_id
Expand Down
34 changes: 34 additions & 0 deletions ext/dapr-ext-grpc/tests/test_servicier.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,40 @@ def test_non_registered_topic(self):
self.fake_context,
)

def test_multiple_wildcard_subscriptions(self):
self._servicer.register_topic(
'pubsub_multi_wildcard',
'orders/+/items',
self._topic1_method,
None,
disable_topic_validation=True,
)
self._servicer.register_topic(
'pubsub_multi_wildcard',
'inventory/#',
self._topic2_method,
None,
disable_topic_validation=True,
)

self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(
pubsub_name='pubsub_multi_wildcard', topic='orders/123/items', path='orders/+/items'
),
self.fake_context,
)
self._topic1_method.assert_called_once()

self._servicer.OnTopicEvent(
appcallback_v1.TopicEventRequest(
pubsub_name='pubsub_multi_wildcard',
topic='inventory/warehouse/aisle4',
path='inventory/#',
),
self.fake_context,
)
self._topic2_method.assert_called_once()


class BulkTopicEventTests(unittest.TestCase):
def setUp(self):
Expand Down