diff --git a/examples/consumer.py b/examples/consumer.py index 8c2985e..d698f48 100755 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -29,8 +29,12 @@ }) while True: - msg = consumer.receive() - print("Received message '{0}' id='{1}'".format(msg.data().decode('utf-8'), msg.message_id())) - consumer.acknowledge(msg) + try: + msg = consumer.receive() + print("Received message '{0}' id='{1}'".format(msg.data().decode('utf-8'), msg.message_id())) + consumer.acknowledge(msg) + except pulsar.Interrupted: + print("Stop receiving messages") + break client.close() diff --git a/pulsar/exceptions.py b/pulsar/exceptions.py index d151564..1b425c8 100644 --- a/pulsar/exceptions.py +++ b/pulsar/exceptions.py @@ -25,4 +25,4 @@ ProducerBlockedQuotaExceededException, ProducerQueueIsFull, MessageTooBig, TopicNotFound, SubscriptionNotFound, \ ConsumerNotFound, UnsupportedVersionError, TopicTerminated, CryptoError, IncompatibleSchema, ConsumerAssignError, \ CumulativeAcknowledgementNotAllowedError, TransactionCoordinatorNotFoundError, InvalidTxnStatusError, \ - NotAllowedError, TransactionConflict, TransactionNotFound, ProducerFenced, MemoryBufferIsFull + NotAllowedError, TransactionConflict, TransactionNotFound, ProducerFenced, MemoryBufferIsFull, Interrupted diff --git a/src/client.cc b/src/client.cc index 206c4e2..0103309 100644 --- a/src/client.cc +++ b/src/client.cc @@ -24,73 +24,38 @@ namespace py = pybind11; Producer Client_createProducer(Client& client, const std::string& topic, const ProducerConfiguration& conf) { - Producer producer; - - waitForAsyncValue(std::function([&](CreateProducerCallback callback) { - client.createProducerAsync(topic, conf, callback); - }), - producer); - - return producer; + return waitForAsyncValue( + [&](CreateProducerCallback callback) { client.createProducerAsync(topic, conf, callback); }); } Consumer Client_subscribe(Client& client, const std::string& topic, const std::string& subscriptionName, const ConsumerConfiguration& conf) { - Consumer consumer; - - waitForAsyncValue(std::function([&](SubscribeCallback callback) { - client.subscribeAsync(topic, subscriptionName, conf, callback); - }), - consumer); - - return consumer; + return waitForAsyncValue( + [&](SubscribeCallback callback) { client.subscribeAsync(topic, subscriptionName, conf, callback); }); } Consumer Client_subscribe_topics(Client& client, const std::vector& topics, const std::string& subscriptionName, const ConsumerConfiguration& conf) { - Consumer consumer; - - waitForAsyncValue(std::function([&](SubscribeCallback callback) { - client.subscribeAsync(topics, subscriptionName, conf, callback); - }), - consumer); - - return consumer; + return waitForAsyncValue( + [&](SubscribeCallback callback) { client.subscribeAsync(topics, subscriptionName, conf, callback); }); } Consumer Client_subscribe_pattern(Client& client, const std::string& topic_pattern, const std::string& subscriptionName, const ConsumerConfiguration& conf) { - Consumer consumer; - - waitForAsyncValue(std::function([&](SubscribeCallback callback) { - client.subscribeWithRegexAsync(topic_pattern, subscriptionName, conf, callback); - }), - consumer); - - return consumer; + return waitForAsyncValue([&](SubscribeCallback callback) { + client.subscribeWithRegexAsync(topic_pattern, subscriptionName, conf, callback); + }); } Reader Client_createReader(Client& client, const std::string& topic, const MessageId& startMessageId, const ReaderConfiguration& conf) { - Reader reader; - - waitForAsyncValue(std::function([&](ReaderCallback callback) { - client.createReaderAsync(topic, startMessageId, conf, callback); - }), - reader); - - return reader; + return waitForAsyncValue( + [&](ReaderCallback callback) { client.createReaderAsync(topic, startMessageId, conf, callback); }); } std::vector Client_getTopicPartitions(Client& client, const std::string& topic) { - std::vector partitions; - - waitForAsyncValue(std::function([&](GetPartitionsCallback callback) { - client.getPartitionsForTopicAsync(topic, callback); - }), - partitions); - - return partitions; + return waitForAsyncValue>( + [&](GetPartitionsCallback callback) { client.getPartitionsForTopicAsync(topic, callback); }); } void Client_close(Client& client) { diff --git a/src/consumer.cc b/src/consumer.cc index 972bd0b..4b44775 100644 --- a/src/consumer.cc +++ b/src/consumer.cc @@ -29,13 +29,7 @@ void Consumer_unsubscribe(Consumer& consumer) { } Message Consumer_receive(Consumer& consumer) { - Message msg; - - waitForAsyncValue(std::function( - [&consumer](ReceiveCallback callback) { consumer.receiveAsync(callback); }), - msg); - - return msg; + return waitForAsyncValue([&](ReceiveCallback callback) { consumer.receiveAsync(callback); }); } Message Consumer_receive_timeout(Consumer& consumer, int timeoutMs) { @@ -59,32 +53,27 @@ Messages Consumer_batch_receive(Consumer& consumer) { void Consumer_acknowledge(Consumer& consumer, const Message& msg) { consumer.acknowledgeAsync(msg, nullptr); } void Consumer_acknowledge_message_id(Consumer& consumer, const MessageId& msgId) { - Py_BEGIN_ALLOW_THREADS - consumer.acknowledgeAsync(msgId, nullptr); + Py_BEGIN_ALLOW_THREADS consumer.acknowledgeAsync(msgId, nullptr); Py_END_ALLOW_THREADS } void Consumer_negative_acknowledge(Consumer& consumer, const Message& msg) { - Py_BEGIN_ALLOW_THREADS - consumer.negativeAcknowledge(msg); + Py_BEGIN_ALLOW_THREADS consumer.negativeAcknowledge(msg); Py_END_ALLOW_THREADS } void Consumer_negative_acknowledge_message_id(Consumer& consumer, const MessageId& msgId) { - Py_BEGIN_ALLOW_THREADS - consumer.negativeAcknowledge(msgId); + Py_BEGIN_ALLOW_THREADS consumer.negativeAcknowledge(msgId); Py_END_ALLOW_THREADS } void Consumer_acknowledge_cumulative(Consumer& consumer, const Message& msg) { - Py_BEGIN_ALLOW_THREADS - consumer.acknowledgeCumulativeAsync(msg, nullptr); + Py_BEGIN_ALLOW_THREADS consumer.acknowledgeCumulativeAsync(msg, nullptr); Py_END_ALLOW_THREADS } void Consumer_acknowledge_cumulative_message_id(Consumer& consumer, const MessageId& msgId) { - Py_BEGIN_ALLOW_THREADS - consumer.acknowledgeCumulativeAsync(msgId, nullptr); + Py_BEGIN_ALLOW_THREADS consumer.acknowledgeCumulativeAsync(msgId, nullptr); Py_END_ALLOW_THREADS } diff --git a/src/future.h b/src/future.h deleted file mode 100644 index 6754c89..0000000 --- a/src/future.h +++ /dev/null @@ -1,181 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef LIB_FUTURE_H_ -#define LIB_FUTURE_H_ - -#include -#include -#include -#include - -#include - -typedef std::unique_lock Lock; - -namespace pulsar { - -template -struct InternalState { - std::mutex mutex; - std::condition_variable condition; - Result result; - Type value; - bool complete; - - std::list > listeners; -}; - -template -class Future { - public: - typedef std::function ListenerCallback; - - Future& addListener(ListenerCallback callback) { - InternalState* state = state_.get(); - Lock lock(state->mutex); - - if (state->complete) { - lock.unlock(); - callback(state->result, state->value); - } else { - state->listeners.push_back(callback); - } - - return *this; - } - - Result get(Type& result) { - InternalState* state = state_.get(); - Lock lock(state->mutex); - - if (!state->complete) { - // Wait for result - while (!state->complete) { - state->condition.wait(lock); - } - } - - result = state->value; - return state->result; - } - - template - bool get(Result& res, Type& value, Duration d) { - InternalState* state = state_.get(); - Lock lock(state->mutex); - - if (!state->complete) { - // Wait for result - while (!state->complete) { - if (!state->condition.wait_for(lock, d, [&state] { return state->complete; })) { - // Timeout while waiting for the future to complete - return false; - } - } - } - - value = state->value; - res = state->result; - return true; - } - - private: - typedef std::shared_ptr > InternalStatePtr; - Future(InternalStatePtr state) : state_(state) {} - - std::shared_ptr > state_; - - template - friend class Promise; -}; - -template -class Promise { - public: - Promise() : state_(std::make_shared >()) {} - - bool setValue(const Type& value) const { - static Result DEFAULT_RESULT; - InternalState* state = state_.get(); - Lock lock(state->mutex); - - if (state->complete) { - return false; - } - - state->value = value; - state->result = DEFAULT_RESULT; - state->complete = true; - - decltype(state->listeners) listeners; - listeners.swap(state->listeners); - - lock.unlock(); - - for (auto& callback : listeners) { - callback(DEFAULT_RESULT, value); - } - - state->condition.notify_all(); - return true; - } - - bool setFailed(Result result) const { - static Type DEFAULT_VALUE; - InternalState* state = state_.get(); - Lock lock(state->mutex); - - if (state->complete) { - return false; - } - - state->result = result; - state->complete = true; - - decltype(state->listeners) listeners; - listeners.swap(state->listeners); - - lock.unlock(); - - for (auto& callback : listeners) { - callback(result, DEFAULT_VALUE); - } - - state->condition.notify_all(); - return true; - } - - bool isComplete() const { - InternalState* state = state_.get(); - Lock lock(state->mutex); - return state->complete; - } - - Future getFuture() const { return Future(state_); } - - private: - typedef std::function ListenerCallback; - std::shared_ptr > state_; -}; - -class Void {}; - -} /* namespace pulsar */ - -#endif /* LIB_FUTURE_H_ */ diff --git a/src/producer.cc b/src/producer.cc index 1dd5a76..7027185 100644 --- a/src/producer.cc +++ b/src/producer.cc @@ -25,21 +25,15 @@ namespace py = pybind11; MessageId Producer_send(Producer& producer, const Message& message) { - MessageId messageId; - - waitForAsyncValue(std::function( - [&](SendCallback callback) { producer.sendAsync(message, callback); }), - messageId); - - return messageId; + return waitForAsyncValue( + [&](SendCallback callback) { producer.sendAsync(message, callback); }); } void Producer_sendAsync(Producer& producer, const Message& msg, SendCallback callback) { - Py_BEGIN_ALLOW_THREADS - producer.sendAsync(msg, callback); + Py_BEGIN_ALLOW_THREADS producer.sendAsync(msg, callback); Py_END_ALLOW_THREADS - if (PyErr_CheckSignals() == -1) { + if (PyErr_CheckSignals() == -1) { PyErr_SetInterrupt(); } } diff --git a/src/reader.cc b/src/reader.cc index 7194c29..0126f3f 100644 --- a/src/reader.cc +++ b/src/reader.cc @@ -62,14 +62,8 @@ Message Reader_readNextTimeout(Reader& reader, int timeoutMs) { } bool Reader_hasMessageAvailable(Reader& reader) { - bool available = false; - - waitForAsyncValue( - std::function( - [&](HasMessageAvailableCallback callback) { reader.hasMessageAvailableAsync(callback); }), - available); - - return available; + return waitForAsyncValue( + [&](HasMessageAvailableCallback callback) { reader.hasMessageAvailableAsync(callback); }); } void Reader_close(Reader& reader) { diff --git a/src/utils.cc b/src/utils.cc index cf8f6f4..8ebc3f9 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -20,28 +20,29 @@ #include "utils.h" void waitForAsyncResult(std::function func) { - Result res = ResultOk; - bool b; - Promise promise; - Future future = promise.getFuture(); + auto promise = std::make_shared>(); + func([promise](Result result) { promise->set_value(result); }); + internal::waitForResult(*promise); +} - Py_BEGIN_ALLOW_THREADS func(WaitForCallback(promise)); - Py_END_ALLOW_THREADS +namespace internal { - bool isComplete; +void waitForResult(std::promise& promise) { + auto future = promise.get_future(); while (true) { - // Check periodically for Python signals - Py_BEGIN_ALLOW_THREADS isComplete = future.get(b, std::ref(res), std::chrono::milliseconds(100)); - Py_END_ALLOW_THREADS - - if (isComplete) { - CHECK_RESULT(res); - return; + { + py::gil_scoped_release release; + auto status = future.wait_for(std::chrono::milliseconds(100)); + if (status == std::future_status::ready) { + CHECK_RESULT(future.get()); + return; + } } - - if (PyErr_CheckSignals() == -1) { - PyErr_SetInterrupt(); - return; + py::gil_scoped_acquire acquire; + if (PyErr_CheckSignals() != 0) { + raiseException(ResultInterrupted); } } } + +} // namespace internal diff --git a/src/utils.h b/src/utils.h index fb700c6..bbe202e 100644 --- a/src/utils.h +++ b/src/utils.h @@ -21,12 +21,14 @@ #include #include +#include #include -#include +#include +#include #include "exceptions.h" -#include "future.h" using namespace pulsar; +namespace py = pybind11; inline void CHECK_RESULT(Result res) { if (res != ResultOk) { @@ -34,56 +36,26 @@ inline void CHECK_RESULT(Result res) { } } -struct WaitForCallback { - Promise m_promise; +namespace internal { - WaitForCallback(Promise promise) : m_promise(promise) {} +void waitForResult(std::promise& promise); - void operator()(Result result) { m_promise.setValue(result); } -}; - -template -struct WaitForCallbackValue { - Promise& m_promise; - - WaitForCallbackValue(Promise& promise) : m_promise(promise) {} - - void operator()(Result result, const T& value) { - if (result == ResultOk) { - m_promise.setValue(value); - } else { - m_promise.setFailed(result); - } - } -}; +} // namespace internal void waitForAsyncResult(std::function func); -template -inline void waitForAsyncValue(std::function func, T& value) { - Result res = ResultOk; - Promise promise; - Future future = promise.getFuture(); - - Py_BEGIN_ALLOW_THREADS func(WaitForCallbackValue(promise)); - Py_END_ALLOW_THREADS +template +inline T waitForAsyncValue(std::function)> func) { + auto resultPromise = std::make_shared>(); + auto valuePromise = std::make_shared>(); - bool isComplete; - while (true) { - // Check periodically for Python signals - Py_BEGIN_ALLOW_THREADS isComplete = future.get(res, std::ref(value), std::chrono::milliseconds(100)); - Py_END_ALLOW_THREADS + func([resultPromise, valuePromise](Result result, const T& value) { + valuePromise->set_value(value); + resultPromise->set_value(result); + }); - if (isComplete) { - CHECK_RESULT(res); - return; - } - - if (PyErr_CheckSignals() == -1) { - PyErr_SetInterrupt(); - return; - } - } + internal::waitForResult(*resultPromise); + return valuePromise->get_future().get(); } struct CryptoKeyReaderWrapper { diff --git a/tests/interrupted_test.py b/tests/interrupted_test.py new file mode 100644 index 0000000..6d61f99 --- /dev/null +++ b/tests/interrupted_test.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python3 +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from unittest import TestCase, main +import pulsar +import signal +import time +import threading + +class InterruptedTest(TestCase): + + service_url = 'pulsar://localhost:6650' + + def test_sigint(self): + def thread_function(): + time.sleep(1) + signal.raise_signal(signal.SIGINT) + + client = pulsar.Client(self.service_url) + consumer = client.subscribe('test-sigint', "my-sub") + thread = threading.Thread(target=thread_function) + thread.start() + + start = time.time() + with self.assertRaises(pulsar.Interrupted): + consumer.receive() + finish = time.time() + print(f"time: {finish - start}") + self.assertGreater(finish - start, 1) + self.assertLess(finish - start, 1.5) + client.close() + +if __name__ == '__main__': + main() diff --git a/tests/run-unit-tests.sh b/tests/run-unit-tests.sh index 13349f9..5168f94 100755 --- a/tests/run-unit-tests.sh +++ b/tests/run-unit-tests.sh @@ -24,4 +24,5 @@ ROOT_DIR=$(git rev-parse --show-toplevel) cd $ROOT_DIR/tests python3 custom_logger_test.py +python3 interrupted_test.py python3 pulsar_test.py