From d3b4614b788dda2d4014f43d299168ce2dc2b56d Mon Sep 17 00:00:00 2001 From: michael1011 Date: Fri, 11 Aug 2023 16:37:15 +0200 Subject: [PATCH] fix: skipped hold invoice plugin tests --- tools/hold/consts.py | 12 ++++++++++++ tools/hold/encoder.py | 12 +++++++----- tools/hold/htlc_handler.py | 18 ++++++++++++++++-- tools/hold/plugin.py | 4 ++++ tools/hold/settler.py | 5 ++--- tools/hold/test_plugin.py | 39 ++++++++++++++++++++++++-------------- tools/pyproject.toml | 2 +- 7 files changed, 67 insertions(+), 25 deletions(-) diff --git a/tools/hold/consts.py b/tools/hold/consts.py index 5deb0205..f4c5ba4b 100644 --- a/tools/hold/consts.py +++ b/tools/hold/consts.py @@ -1,4 +1,16 @@ +from enum import Enum + + +class Network(str, Enum): + Mainnet = "bitcoin" + Testnet = "testnet" + Signet = "signet" + Regtest = "regtest" + + PLUGIN_NAME = "hold" TIMEOUT_CANCEL = 60 +TIMEOUT_CANCEL_REGTEST = 5 + TIMEOUT_CHECK_INTERVAL = 10 diff --git a/tools/hold/encoder.py b/tools/hold/encoder.py index 29cc5f92..97358f17 100644 --- a/tools/hold/encoder.py +++ b/tools/hold/encoder.py @@ -3,6 +3,7 @@ from bolt11 import Bolt11, Feature, Features, FeatureState, encode from bolt11.types import MilliSatoshi +from consts import Network from pyln.client import Plugin from secp256k1 import PrivateKey from utils import time_now @@ -11,10 +12,10 @@ NETWORK_PREFIXES = { - "bitcoin": "bc", - "testnet": "tb", - "signet": "tbs", - "regtest": "bcrt", + Network.Mainnet: "bc", + Network.Testnet: "tb", + Network.Signet: "tbs", + Network.Regtest: "bcrt", } @@ -24,8 +25,9 @@ class Defaults(int, Enum): def get_network_prefix(network: str) -> str: + # noinspection PyTypeChecker return NETWORK_PREFIXES[network] if network in NETWORK_PREFIXES \ - else NETWORK_PREFIXES["bitcoin"] + else NETWORK_PREFIXES[Network.Mainnet] def get_payment_secret(val: str | None) -> str: diff --git a/tools/hold/htlc_handler.py b/tools/hold/htlc_handler.py index 09b25cfb..91442d25 100644 --- a/tools/hold/htlc_handler.py +++ b/tools/hold/htlc_handler.py @@ -2,7 +2,12 @@ from typing import Any from urllib.request import Request -from consts import TIMEOUT_CHECK_INTERVAL +from consts import ( + TIMEOUT_CANCEL, + TIMEOUT_CANCEL_REGTEST, + TIMEOUT_CHECK_INTERVAL, + Network, +) from datastore import DataStore from invoice import HoldInvoice, InvoiceState from pyln.client import Plugin @@ -20,9 +25,18 @@ def __init__(self, plugin: Plugin, ds: DataStore, settler: Settler) -> None: self._plugin = plugin self._ds = ds self._settler = settler + self._timeout = TIMEOUT_CANCEL self._start_timeout_interval() + def init(self) -> None: + if self._plugin.rpc.getinfo()["network"] == Network.Regtest: + self._timeout = TIMEOUT_CANCEL_REGTEST + self._plugin.log( + f"Using regtest MPP timeout of {self._timeout} seconds", + level="warn", + ) + def handle_htlc( self, invoice: HoldInvoice, @@ -69,4 +83,4 @@ def _timeout_handler(self) -> None: with self._lock: for htlcs in self._settler.htlcs.values(): if not htlcs.is_fully_paid(): - htlcs.cancel_expired() + htlcs.cancel_expired(self._timeout) diff --git a/tools/hold/plugin.py b/tools/hold/plugin.py index 1b5327d5..ef5fa0e9 100755 --- a/tools/hold/plugin.py +++ b/tools/hold/plugin.py @@ -14,6 +14,8 @@ # TODO: fix shebang line # TODO: restart handling +# TODO: docstrings +# TODO: gRPC with subs pl = Plugin() @@ -31,6 +33,7 @@ def init( plugin: Plugin, **kwargs: dict[str, Any], ) -> None: + handler.init() encoder.init() plugin.log(f"Plugin {PLUGIN_NAME} initialized") @@ -40,6 +43,7 @@ def hold_invoice( plugin: Plugin, payment_hash: str, amount_msat: int, + # TODO: remove default when library can handle empty strings memo: str = "Hold invoice", expiry: int = Defaults.Expiry, min_final_cltv_expiry: int = Defaults.MinFinalCltvExpiry, diff --git a/tools/hold/settler.py b/tools/hold/settler.py index 36f7aa87..0d670718 100644 --- a/tools/hold/settler.py +++ b/tools/hold/settler.py @@ -3,7 +3,6 @@ from enum import Enum from typing import ClassVar -from consts import TIMEOUT_CANCEL from invoice import HoldInvoice, InvoiceState from pyln.client.plugin import Request from utils import partition, time_now @@ -37,12 +36,12 @@ def is_fully_paid(self) -> bool: def requests(self) -> list[Request]: return [h.request for h in self.htlcs] - def cancel_expired(self) -> None: + def cancel_expired(self, expiry: int) -> None: expired, not_expired = partition( self.htlcs, lambda htlc: ( time_now() - htlc.creation_time - ).total_seconds() > TIMEOUT_CANCEL, + ).total_seconds() > expiry, ) self.htlcs = not_expired diff --git a/tools/hold/test_plugin.py b/tools/hold/test_plugin.py index 2a7a4c50..05331854 100644 --- a/tools/hold/test_plugin.py +++ b/tools/hold/test_plugin.py @@ -8,7 +8,9 @@ from threading import Thread from typing import Any +import bolt11 import pytest +from bolt11.types import MilliSatoshi from cli_utils import CliCaller, cln_con PLUGIN_PATH = "/tools/hold/plugin.py" @@ -81,10 +83,12 @@ def __init__( invoice: str, max_shard_size: int | None = None, outgoing_chan_id: str | None = None, + timeout: int | None = None, ) -> None: Thread.__init__(self) self.node = node + self.timeout = timeout self.invoice = invoice self.max_shard_size = max_shard_size self.outgoing_chan_id = outgoing_chan_id @@ -98,6 +102,9 @@ def run(self) -> None: if self.max_shard_size is not None: cmd += f" --max_shard_size_sat {self.max_shard_size}" + if self.timeout is not None: + cmd += f" --timeout {self.timeout}s" + res = lnd_raw(self.node, f"{cmd} {self.invoice} 2> /dev/null") res = res[res.find("{"):] self.res = json.loads(res) @@ -171,7 +178,6 @@ def test_list_not_found(self, cln: CliCaller) -> None: def test_settle_accepted(self, cln: CliCaller) -> None: payment_preimage, payment_hash, invoice = add_hold_invoice(cln) - print(invoice) pay = LndPay(LndNode.One, invoice) pay.start() @@ -294,24 +300,30 @@ def test_cancel_non_existent(self, cln: CliCaller) -> None: assert res["code"] == 2102 assert res["message"] == "hold invoice with that payment hash does not exist" - @pytest.mark.skip() def test_mpp_timeout(self, cln: CliCaller) -> None: _, payment_hash, invoice = add_hold_invoice(cln) - amount = lnd(LndNode.One, "decodepayreq", invoice)["num_satoshis"] - cln_node = cln("getinfo")["id"] + dec = bolt11.decode(invoice) + dec.amount_msat = MilliSatoshi(dec.amount_msat - 1000) - routes = lnd(LndNode.One, "queryroutes", cln_node, str(int(amount) - 1)) + less_invoice = cln("signinvoice", bolt11.encode(dec))["bolt11"] - res = lnd( - LndNode.One, - "sendtoroute", - "--payment_hash", + pay = LndPay(LndNode.One, less_invoice, timeout=5) + pay.start() + + time.sleep(0.5) + assert cln( + "listholdinvoices", payment_hash, - format_json(routes), - ) + )["holdinvoices"][0]["state"] == "unpaid" - assert res["status"] == "FAILED" - assert res["failure"]["code"] == "MPP_TIMEOUT" + pay.join() + + assert pay.res["status"] == "FAILED" + assert pay.res["failure_reason"] == "FAILURE_REASON_TIMEOUT" + assert len(pay.res["htlcs"]) == 1 + + htlc = pay.res["htlcs"][0] + assert htlc["failure"]["code"] == "MPP_TIMEOUT" def test_htlc_too_little_cltv(self, cln: CliCaller) -> None: _, payment_hash, invoice = add_hold_invoice(cln) @@ -422,7 +434,6 @@ def test_ignore_non_hold(self, cln: CliCaller) -> None: assert pay.res["status"] == "SUCCEEDED" - @pytest.mark.skip() def test_ignore_forward(self, cln: CliCaller) -> None: cln_id = cln("getinfo")["id"] channels = lnd(LndNode.Two, "listchannels")["channels"] diff --git a/tools/pyproject.toml b/tools/pyproject.toml index 74b6efa5..d6c48e50 100644 --- a/tools/pyproject.toml +++ b/tools/pyproject.toml @@ -25,5 +25,5 @@ select = ["ALL"] ignore = [ "T201", "D101", "D211", "D213", "INP001", "BLE001", "FBT001", "FBT002", "FBT003", "S605", "TD002", "TD003", "FIX002", "ANN101", "D102", "D103", "D107", "D100", "SLOT000", "S101", - "PLR2004", "ARG001", "PLR0913", "D104" + "PLR2004", "ARG001", "PLR0913", "D104", "FA102" ]