diff --git a/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py b/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py index a729fab8e..9694e1c53 100644 --- a/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py +++ b/ext/dapr-ext-grpc/dapr/ext/grpc/_servicer.py @@ -81,6 +81,25 @@ def __init__(self): self._registered_topics: List[appcallback_v1.TopicSubscription] = [] self._registered_bindings: List[str] = [] + self._route_map: Dict[Tuple[str, str], TopicSubscribeCallable] = {} + self._validation_disabled_pubsubs: Dict[str, TopicSubscribeCallable] = {} + + def _get_topic_callback( + self, pubsub_name: str, topic: str, path: str + ) -> Optional[TopicSubscribeCallable]: + pubsub_topic = pubsub_name + DELIMITER + topic + DELIMITER + path + if pubsub_topic in self._topic_map: + return self._topic_map[pubsub_topic] + + if (pubsub_name, path) in self._route_map: + return self._route_map[(pubsub_name, path)] + + if path == '': + if pubsub_name in self._validation_disabled_pubsubs: + return self._validation_disabled_pubsubs[pubsub_name] + + return None + def register_method(self, method: str, cb: InvokeMethodCallable) -> None: """Registers method for service invocation.""" if method in self._invoke_method_map: @@ -98,10 +117,7 @@ def register_topic( disable_topic_validation: Optional[bool] = False, ) -> None: """Registers topic subscription for pubsub.""" - if not disable_topic_validation: - topic_key = pubsub_name + DELIMITER + topic - else: - topic_key = pubsub_name + topic_key = pubsub_name + DELIMITER + topic pubsub_topic = topic_key + DELIMITER if rule is not None: path = getattr(cb, '__name__', rule.match) @@ -109,6 +125,10 @@ def register_topic( if pubsub_topic in self._topic_map: raise ValueError(f'{topic} is already registered with {pubsub_name}') self._topic_map[pubsub_topic] = cb + self._route_map[(pubsub_name, topic)] = cb + + if disable_topic_validation: + self._validation_disabled_pubsubs[pubsub_name] = cb registered_topic = self._registered_topics_map.get(topic_key) sub: appcallback_v1.TopicSubscription = appcallback_v1.TopicSubscription() @@ -196,15 +216,10 @@ def ListTopicSubscriptions(self, request, context): def OnTopicEvent(self, request: TopicEventRequest, context): """Subscribes events from Pubsub.""" - pubsub_topic = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path - no_validation_key = request.pubsub_name + DELIMITER + request.path - - if pubsub_topic not in self._topic_map: - if no_validation_key in self._topic_map: - pubsub_topic = no_validation_key - else: - context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore - raise NotImplementedError(f'topic {request.topic} is not implemented!') + cb = self._get_topic_callback(request.pubsub_name, request.topic, request.path) + if cb is None: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) # type: ignore + raise NotImplementedError(f'topic {request.topic} is not implemented!') customdata: Struct = request.extensions extensions = dict() @@ -222,7 +237,7 @@ def OnTopicEvent(self, request: TopicEventRequest, context): event.SetSubject(request.topic) event.SetExtensions(extensions) - response = self._topic_map[pubsub_topic](event) + response = cb(event) if isinstance(response, TopicEventResponse): return appcallback_v1.TopicEventResponse(status=response.status.value) return empty_pb2.Empty() @@ -292,15 +307,10 @@ def _handle_bulk_topic_event( self, request: TopicEventBulkRequest, context ) -> Optional[TopicEventBulkResponse]: """Process bulk topic event request - routes each entry to the appropriate topic handler.""" - topic_key = request.pubsub_name + DELIMITER + request.topic + DELIMITER + request.path - no_validation_key = request.pubsub_name + DELIMITER + request.path - - if topic_key not in self._topic_map and no_validation_key not in self._topic_map: + cb = self._get_topic_callback(request.pubsub_name, request.topic, request.path) + if cb is None: return None # we don't have a handler - handler_key = topic_key if topic_key in self._topic_map else no_validation_key - cb = self._topic_map[handler_key] # callback - statuses = [] for entry in request.entries: entry_id = entry.entry_id diff --git a/ext/dapr-ext-grpc/tests/test_servicier.py b/ext/dapr-ext-grpc/tests/test_servicier.py index 1fe18c5dc..ff6910ee0 100644 --- a/ext/dapr-ext-grpc/tests/test_servicier.py +++ b/ext/dapr-ext-grpc/tests/test_servicier.py @@ -182,6 +182,40 @@ def test_non_registered_topic(self): self.fake_context, ) + def test_multiple_wildcard_subscriptions(self): + self._servicer.register_topic( + 'pubsub_multi_wildcard', + 'orders/+/items', + self._topic1_method, + None, + disable_topic_validation=True, + ) + self._servicer.register_topic( + 'pubsub_multi_wildcard', + 'inventory/#', + self._topic2_method, + None, + disable_topic_validation=True, + ) + + self._servicer.OnTopicEvent( + appcallback_v1.TopicEventRequest( + pubsub_name='pubsub_multi_wildcard', topic='orders/123/items', path='orders/+/items' + ), + self.fake_context, + ) + self._topic1_method.assert_called_once() + + self._servicer.OnTopicEvent( + appcallback_v1.TopicEventRequest( + pubsub_name='pubsub_multi_wildcard', + topic='inventory/warehouse/aisle4', + path='inventory/#', + ), + self.fake_context, + ) + self._topic2_method.assert_called_once() + class BulkTopicEventTests(unittest.TestCase): def setUp(self):