diff --git a/awscrt/mqtt5.py b/awscrt/mqtt5.py index 7c5e4f31f..6768bdf6b 100644 --- a/awscrt/mqtt5.py +++ b/awscrt/mqtt5.py @@ -495,6 +495,27 @@ def _try_puback_reason_code(value): except Exception: return None +class ManualPubackResult(IntEnum): + """Result for a manually invoked PUBACK operation.""" + + SUCCESS = 0 + """The PUBACK was successfully sent.""" + + PUBACK_CANCELLED = 1 + """The PUBACK was cancelled and will not be sent.""" + + PUBACK_INVALID = 2 + """The PUBACK attempting to be sent is invalid.""" + + CRT_FAILURE = 3 + """The PUBACK failed to send due to a CRT failure.""" + + +def _try_manual_puback_result(value): + try: + return ManualPubackResult(value) + except Exception: + return None class SubackReasonCode(IntEnum): """Reason code inside SUBACK packet payloads. @@ -1140,6 +1161,15 @@ class PubackPacket: reason_string: str = None user_properties: 'Sequence[UserProperty]' = None +@dataclass +class InvokePubackCompletion: + """dataclass containing results of a manually invoked PUBACK + + Args: + puback_result (ManualPubackResult): Result of manually invoked PUBACK + """ + puback_result: ManualPubackResult = None + @dataclass class ConnectPacket: @@ -1228,8 +1258,10 @@ class PublishReceivedData: Args: publish_packet (PublishPacket): Data model of an `MQTT5 PUBLISH `_ packet. + acquire_puback_control (Callable): Call this function to prevent automatic PUBACK and take manual control of this PUBLISH message's PUBACK. Returns an opaque handle object that can be passed to Client.invoke_puback(). """ publish_packet: PublishPacket = None + acquire_puback_control: Callable = None @dataclass @@ -1434,7 +1466,8 @@ def _on_publish( correlation_data, subscription_identifiers_tuples, content_type, - user_properties_tuples): + user_properties_tuples, + acquire_puback_control_fn): if self._on_publish_cb is None: return @@ -1468,9 +1501,13 @@ def _on_publish( publish_packet.content_type = content_type publish_packet.user_properties = _init_user_properties(user_properties_tuples) - self._on_publish_cb(PublishReceivedData(publish_packet=publish_packet)) + # Create PublishReceivedData with the manual control callback + publish_data = PublishReceivedData( + publish_packet=publish_packet, + acquire_puback_control=acquire_puback_control_fn + ) - return + self._on_publish_cb(publish_data) def _on_lifecycle_stopped(self): if self._on_lifecycle_stopped_cb: @@ -1957,6 +1994,30 @@ def get_stats(self): result = _awscrt.mqtt5_client_get_stats(self._binding) return OperationStatisticsData(result[0], result[1], result[2], result[3]) + def invoke_puback(self, puback_control_handle): + """Sends a PUBACK packet for the given puback control handle. + + Args: + puback_control_handle: An opaque handle obtained from acquire_puback_control(). This handle cannot be created manually and must come from the acquire_puback_control() Callable within PublishReceivedData. + + Returns: + A future with InvokePubackCompletion that completes when invoked PUBACK is sent or fails to send. A successfully sent PUBACK only confirms the requested PUBACK has been sent, not that the broker has received it and/or it hasn't re-sent the PUBLISH message being acknowledged. + """ + + future = Future() + + def invokePubackComplete(puback_result): + invokePubackCompletion = InvokePubackCompletion(puback_result=_try_manual_puback_result(puback_result)) + future.set_result(invokePubackCompletion) + + _awscrt.mqtt5_client_invoke_puback( + self._binding, + puback_control_handle, + invokePubackComplete + ) + + return future + def new_connection(self, on_connection_interrupted=None, on_connection_resumed=None, on_connection_success=None, on_connection_failure=None, on_connection_closed=None): from awscrt.mqtt import Connection diff --git a/source/module.c b/source/module.c index 0b752e03d..6e2452246 100644 --- a/source/module.c +++ b/source/module.c @@ -810,6 +810,7 @@ static PyMethodDef s_module_methods[] = { AWS_PY_METHOD_DEF(mqtt5_client_subscribe, METH_VARARGS), AWS_PY_METHOD_DEF(mqtt5_client_unsubscribe, METH_VARARGS), AWS_PY_METHOD_DEF(mqtt5_client_get_stats, METH_VARARGS), + AWS_PY_METHOD_DEF(mqtt5_client_invoke_puback, METH_VARARGS), AWS_PY_METHOD_DEF(mqtt5_ws_handshake_transform_complete, METH_VARARGS), /* MQTT Request Response Client */ diff --git a/source/mqtt5_client.c b/source/mqtt5_client.c index 243af6a0e..7e6c0dcea 100644 --- a/source/mqtt5_client.c +++ b/source/mqtt5_client.c @@ -218,6 +218,60 @@ static PyObject *s_aws_set_user_properties_to_PyObject( * Publish Handler ******************************************************************************/ +static const char *s_capsule_name_puback_control_handle = "aws_puback_control_handle"; + +struct puback_control_handle { + uint64_t control_id; +}; + +static void s_puback_control_handle_destructor(PyObject *capsule) { + struct puback_control_handle *handle = PyCapsule_GetPointer(capsule, s_capsule_name_puback_control_handle); + if (handle) { + aws_mem_release(aws_py_get_allocator(), handle); + } +} + +/* Callback context for manual PUBACK control */ +struct manual_puback_control_context { + struct aws_mqtt5_client *client; + struct aws_mqtt5_packet_publish_view *publish_packet; +}; + +static void s_manual_puback_control_context_destructor(PyObject *capsule) { + struct manual_puback_control_context *context = PyCapsule_GetPointer(capsule, "manual_puback_control_context"); + if (context) { + aws_mem_release(aws_py_get_allocator(), context); + } +} + +/* Function called from Python to set manual PUBACK control and return puback_control_id */ +PyObject *aws_py_mqtt5_client_acquire_puback(PyObject *self, PyObject *args) { + (void)args; + + struct manual_puback_control_context *context = PyCapsule_GetPointer(self, "manual_puback_control_context"); + if (!context || !context->publish_packet) { + PyErr_SetString(PyExc_ValueError, "Invalid manual PUBACK control context"); + return NULL; + } + + uint64_t puback_control_id = aws_mqtt5_client_acquire_puback(context->client, context->publish_packet); + + /* Create handle struct */ + struct puback_control_handle *handle = + aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct puback_control_handle)); + + handle->control_id = puback_control_id; + + /* Wrap in capsule */ + PyObject *capsule = PyCapsule_New(handle, s_capsule_name_puback_control_handle, s_puback_control_handle_destructor); + if (!capsule) { + aws_mem_release(aws_py_get_allocator(), handle); + return NULL; + } + + return capsule; +} + static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *publish_packet, void *user_data) { if (!user_data) { @@ -234,10 +288,46 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu PyObject *result = NULL; PyObject *subscription_identifier_list = NULL; PyObject *user_properties_list = NULL; + PyObject *manual_control_callback = NULL; + PyObject *control_context_capsule = NULL; size_t subscription_identifier_count = publish_packet->subscription_identifier_count; size_t user_property_count = publish_packet->user_property_count; + /* Create manual PUBACK control context */ + struct manual_puback_control_context *control_context = + aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct manual_puback_control_context)); + if (!control_context) { + PyErr_WriteUnraisable(PyErr_Occurred()); + goto cleanup; + } + + /* Set up the context with both client and publish packet */ + control_context->client = client->native; + control_context->publish_packet = (struct aws_mqtt5_packet_publish_view *)publish_packet; + + control_context_capsule = + PyCapsule_New(control_context, "manual_puback_control_context", s_manual_puback_control_context_destructor); + if (!control_context_capsule) { + aws_mem_release(aws_py_get_allocator(), control_context); + PyErr_WriteUnraisable(PyErr_Occurred()); + goto cleanup; + } + + /* Method definition for the manual control callback */ + static PyMethodDef method_def = { + "acquire_puback_control", + aws_py_mqtt5_client_acquire_puback, + METH_NOARGS, + "Take manual control of PUBACK for this message"}; + + /* Create the manual control callback function */ + manual_control_callback = PyCFunction_New(&method_def, control_context_capsule); + if (!manual_control_callback) { + PyErr_WriteUnraisable(PyErr_Occurred()); + goto cleanup; + } + /* Create list of uint32_t subscription identifier tuples */ subscription_identifier_list = PyList_New(subscription_identifier_count); if (!subscription_identifier_list) { @@ -261,7 +351,7 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu result = PyObject_CallMethod( client->client_core, "_on_publish", - "(y#iOs#OiOIOHs#y#Os#O)", + "(y#iOs#OiOIOHs#y#Os#OO)", /* y */ publish_packet->payload.ptr, /* # */ publish_packet->payload.len, /* i */ (int)publish_packet->qos, @@ -284,7 +374,9 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu /* O */ subscription_identifier_count > 0 ? subscription_identifier_list : Py_None, /* s */ publish_packet->content_type ? publish_packet->content_type->ptr : NULL, /* # */ publish_packet->content_type ? publish_packet->content_type->len : 0, - /* O */ user_property_count > 0 ? user_properties_list : Py_None); + /* O */ user_property_count > 0 ? user_properties_list : Py_None, + /* O */ manual_control_callback); + if (!result) { PyErr_WriteUnraisable(PyErr_Occurred()); } @@ -293,6 +385,8 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu Py_XDECREF(result); Py_XDECREF(subscription_identifier_list); Py_XDECREF(user_properties_list); + Py_XDECREF(manual_control_callback); + Py_XDECREF(control_context_capsule); PyGILState_Release(state); } @@ -1683,6 +1777,98 @@ PyObject *aws_py_mqtt5_client_publish(PyObject *self, PyObject *args) { return NULL; } +/******************************************************************************* + * Invoke Puback + ******************************************************************************/ + +struct invoke_puback_complete_userdata { + PyObject *callback; +}; + +static void s_on_invoke_puback_complete_fn(enum aws_mqtt5_manual_puback_result puback_result, void *complete_ctx) { + struct invoke_puback_complete_userdata *metadata = complete_ctx; + assert(metadata); + + PyObject *result = NULL; + + PyGILState_STATE state; + if (aws_py_gilstate_ensure(&state)) { + return; /* Python has shut down. Nothing matters anymore, but don't crash */ + } + + result = PyObject_CallFunction(metadata->callback, "(i)", (int)puback_result); + + if (!result) { + PyErr_WriteUnraisable(PyErr_Occurred()); + } + + // cleanup + Py_XDECREF(metadata->callback); + Py_XDECREF(result); + + PyGILState_Release(state); + + aws_mem_release(aws_py_get_allocator(), metadata); +} + +PyObject *aws_py_mqtt5_client_invoke_puback(PyObject *self, PyObject *args) { + (void)self; + bool success = false; + + PyObject *impl_capsule; + PyObject *puback_handle_capsule; + PyObject *manual_puback_callback_fn_py; + + if (!PyArg_ParseTuple( + args, + "OOO", + /* O */ &impl_capsule, + /* O */ &puback_handle_capsule, + /* O */ &manual_puback_callback_fn_py)) { + return NULL; + } + + struct mqtt5_client_binding *client = PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt5_client); + if (!client) { + return NULL; + } + + /* Extract handle from capsule */ + struct puback_control_handle *handle = + PyCapsule_GetPointer(puback_handle_capsule, s_capsule_name_puback_control_handle); + if (!handle) { + PyErr_SetString(PyExc_TypeError, "Invalid PUBACK control handle"); + return NULL; + } + + /* callback related must be cleaned up after this point */ + struct invoke_puback_complete_userdata *metadata = + aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct invoke_puback_complete_userdata)); + metadata->callback = manual_puback_callback_fn_py; + Py_INCREF(metadata->callback); + + struct aws_mqtt5_manual_puback_completion_options manual_puback_completion_options = { + .completion_callback = &s_on_invoke_puback_complete_fn, .completion_user_data = metadata}; + + if (aws_mqtt5_client_invoke_puback(client->native, handle->control_id, &manual_puback_completion_options)) { + PyErr_SetAwsLastError(); + goto manual_puback_failed; + } + + success = true; + goto done; + +manual_puback_failed: + Py_XDECREF(manual_puback_callback_fn_py); + aws_mem_release(aws_py_get_allocator(), metadata); + +done: + if (success) { + Py_RETURN_NONE; + } + return NULL; +} + /******************************************************************************* * Subscribe ******************************************************************************/ diff --git a/source/mqtt5_client.h b/source/mqtt5_client.h index 46c135f82..b9bc54f16 100644 --- a/source/mqtt5_client.h +++ b/source/mqtt5_client.h @@ -14,6 +14,7 @@ PyObject *aws_py_mqtt5_client_publish(PyObject *self, PyObject *args); PyObject *aws_py_mqtt5_client_subscribe(PyObject *self, PyObject *args); PyObject *aws_py_mqtt5_client_unsubscribe(PyObject *self, PyObject *args); PyObject *aws_py_mqtt5_client_get_stats(PyObject *self, PyObject *args); +PyObject *aws_py_mqtt5_client_invoke_puback(PyObject *self, PyObject *args); PyObject *aws_py_mqtt5_ws_handshake_transform_complete(PyObject *self, PyObject *args);