From 84e01f777139c911a90c86b8301ac8a111cbc119 Mon Sep 17 00:00:00 2001 From: michael1011 Date: Mon, 14 Aug 2023 16:37:38 +0200 Subject: [PATCH] feat: routing hints in hold invoice plugin --- .github/workflows/docker-publish.yml | 9 +-- docker/build.py | 2 +- docker/regtest/scripts/setup.sh | 24 +++++-- docker/regtest/startRegtest.sh | 2 +- tools/hold/encoder.py | 31 +++++---- tools/hold/hold.py | 9 +++ tools/hold/plugin.py | 19 +++++- tools/hold/protos/hold.proto | 38 ++++++++--- tools/hold/protos/hold_pb2.py | 68 +++++++++++--------- tools/hold/protos/hold_pb2.pyi | 85 ++++++++++++++++++++----- tools/hold/protos/hold_pb2_grpc.py | 95 ++++++++++++++++++++-------- tools/hold/route_hints.py | 31 +++++++++ tools/hold/server.py | 40 +++++++----- tools/hold/test_encoder.py | 64 ++++++++++++++++--- tools/hold/test_grpc.py | 87 ++++++++++++++++++++++++- tools/hold/test_plugin.py | 76 +++++++++++++++++++++- tools/hold/test_route_hints.py | 35 ++++++++++ tools/hold/test_utils.py | 27 +++++++- tools/hold/transformers.py | 81 ++++++++++++++++++++++++ 19 files changed, 683 insertions(+), 140 deletions(-) create mode 100644 tools/hold/route_hints.py create mode 100644 tools/hold/test_route_hints.py diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml index b38adc5a..143b5633 100644 --- a/.github/workflows/docker-publish.yml +++ b/.github/workflows/docker-publish.yml @@ -9,19 +9,20 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - name: Check out code + uses: actions/checkout@v3 - name: Set up QEMU - uses: docker/setup-qemu-action@v1 + uses: docker/setup-qemu-action@v2 - name: Set up Docker Buildx id: buildx - uses: docker/setup-buildx-action@v1 + uses: docker/setup-buildx-action@v2 with: install: true - name: Login to GitHub Container Registry - uses: docker/login-action@v1 + uses: docker/login-action@v2 with: registry: ghcr.io username: ${{ github.repository_owner }} diff --git a/docker/build.py b/docker/build.py index d291f406..b36a3284 100755 --- a/docker/build.py +++ b/docker/build.py @@ -101,7 +101,7 @@ class Image: ], ), "regtest": Image( - tags=["4.0.0"], + tags=["4.0.1"], arguments=[ UBUNTU_VERSION, BITCOIN_BUILD_ARG, diff --git a/docker/regtest/scripts/setup.sh b/docker/regtest/scripts/setup.sh index 3806ef40..90c54200 100755 --- a/docker/regtest/scripts/setup.sh +++ b/docker/regtest/scripts/setup.sh @@ -23,8 +23,13 @@ function openChannel () { waitForLndToSync "$2" - $2 connect $3@127.0.0.1:$4 > /dev/null - $2 openchannel --node_key $3 --local_amt 100000000 --push_amt 50000000 > /dev/null + $2 connect $3@127.0.0.1:$4 > /dev/null 2> /dev/null + + if $5; then + $2 openchannel --node_key $3 --local_amt 100000000 --push_amt 50000000 --private > /dev/null + else + $2 openchannel --node_key $3 --local_amt 100000000 --push_amt 50000000 > /dev/null + fi $1 generatetoaddress 6 ${nodeAddress} > /dev/null @@ -54,8 +59,8 @@ function waitForClnChannel () { sleep 25 } -echo "/tools/.venv/bin/python3 /tools/hold/plugin.py" > /root/hold-start.sh -chmod +x /root/hold-start.sh +echo "/tools/.venv/bin/python3 /tools/hold/plugin.py" > /root/hold.sh +chmod +x /root/hold.sh startNodes @@ -80,18 +85,23 @@ echo "Opening BTC channels" openChannel bitcoin-cli \ "lncli --lnddir=/root/.lnd-btc --rpcserver=127.0.0.1:10009 --network=regtest" \ $(lncli --lnddir=/root/.lnd-btc --rpcserver=127.0.0.1:10011 --network=regtest getinfo | jq -r '.identity_pubkey') \ - 9736 + 9736 false echo "Opened channel to LND" openChannel bitcoin-cli \ "lncli --lnddir=/root/.lnd-btc --rpcserver=127.0.0.1:10009 --network=regtest" \ $(lightning-cli getinfo | jq -r .id) \ - 9737 + 9737 false + +openChannel bitcoin-cli \ + "lncli --lnddir=/root/.lnd-btc --rpcserver=127.0.0.1:10009 --network=regtest" \ + $(lightning-cli getinfo | jq -r .id) \ + 9737 true openChannel bitcoin-cli \ "lncli --lnddir=/root/.lnd-btc --rpcserver=127.0.0.1:10011 --network=regtest" \ $(lightning-cli getinfo | jq -r .id) \ - 9737 + 9737 false echo "Opened channels to CLN" diff --git a/docker/regtest/startRegtest.sh b/docker/regtest/startRegtest.sh index c53ee89a..80d7e45f 100755 --- a/docker/regtest/startRegtest.sh +++ b/docker/regtest/startRegtest.sh @@ -34,4 +34,4 @@ docker run \ -p 31000:31000 \ -p 31001:31001 \ -p 31002:31002 \ - boltz/regtest:4.0.0 + boltz/regtest:4.0.1 diff --git a/tools/hold/encoder.py b/tools/hold/encoder.py index 845f172e..5a0b7eb1 100644 --- a/tools/hold/encoder.py +++ b/tools/hold/encoder.py @@ -2,15 +2,12 @@ from enum import Enum from bolt11 import Bolt11, Feature, Features, FeatureState, encode -from bolt11.types import MilliSatoshi, Tag, TagChar, Tags +from bolt11.types import MilliSatoshi, RouteHint, Tag, TagChar, Tags from consts import Network from pyln.client import Plugin from secp256k1 import PrivateKey from utils import time_now -# TODO: routing hints - - NETWORK_PREFIXES = { Network.Mainnet: "bc", Network.Testnet: "tb", @@ -71,21 +68,27 @@ def encode( expiry: int = Defaults.Expiry, min_final_cltv_expiry: int = Defaults.MinFinalCltvExpiry, payment_secret: str | None = None, + route_hints: list[RouteHint] | None = None, ) -> str: + tags = Tags( + [ + Tag(TagChar.payment_hash, payment_hash), + Tag(TagChar.description, description), + Tag(TagChar.expire_time, expiry), + Tag(TagChar.min_final_cltv_expiry, min_final_cltv_expiry), + Tag(TagChar.payment_secret, get_payment_secret(payment_secret)), + Tag(TagChar.features, self._features), + ] + ) + + if route_hints is not None: + tags.tags.extend([Tag(TagChar.route_hint, route) for route in route_hints]) + return encode( Bolt11( self._prefix, int(time_now().timestamp()), - Tags( - [ - Tag(TagChar.payment_hash, payment_hash), - Tag(TagChar.description, description), - Tag(TagChar.expire_time, expiry), - Tag(TagChar.min_final_cltv_expiry, min_final_cltv_expiry), - Tag(TagChar.payment_secret, get_payment_secret(payment_secret)), - Tag(TagChar.features, self._features), - ] - ), + tags, MilliSatoshi(amount_msat), ), self._key, diff --git a/tools/hold/hold.py b/tools/hold/hold.py index 3b7daee5..8631db7f 100644 --- a/tools/hold/hold.py +++ b/tools/hold/hold.py @@ -1,10 +1,12 @@ import hashlib +from bolt11.types import RouteHint from datastore import DataErrorCodes, DataStore from encoder import Encoder from htlc_handler import HtlcHandler from invoice import HoldInvoice, InvoiceState from pyln.client import Plugin, RpcError +from route_hints import RouteHints from settler import Settler from tracker import Tracker @@ -23,6 +25,7 @@ def __init__(self, plugin: Plugin) -> None: self.tracker = Tracker() self._settler = Settler(self.tracker) self._encoder = Encoder(plugin) + self._route_hints = RouteHints(plugin) self.ds = DataStore(plugin, self._settler) self.handler = HtlcHandler(plugin, self.ds, self._settler, self.tracker) @@ -30,6 +33,7 @@ def __init__(self, plugin: Plugin) -> None: def init(self) -> None: self.handler.init() self._encoder.init() + self._route_hints.init() def invoice( self, @@ -38,6 +42,7 @@ def invoice( description: str, expiry: int, min_final_cltv_expiry: int, + route_hints: list[RouteHint] | None = None, ) -> str: if ( len(self._plugin.rpc.listinvoices(payment_hash=payment_hash)["invoices"]) @@ -51,6 +56,7 @@ def invoice( description, expiry, min_final_cltv_expiry, + route_hints=route_hints, ) signed = self._plugin.rpc.call( "signinvoice", @@ -103,3 +109,6 @@ def wipe(self, payment_hash: str | None) -> int: return 1 raise NoSuchInvoiceError + + def get_private_channels(self, node: str) -> list[RouteHint]: + return self._route_hints.get_private_channels(node) diff --git a/tools/hold/plugin.py b/tools/hold/plugin.py index 58d8cf2c..86f75a1b 100755 --- a/tools/hold/plugin.py +++ b/tools/hold/plugin.py @@ -10,12 +10,12 @@ from pyln.client.plugin import Request from server import Server from settler import HtlcFailureMessage, Settler +from transformers import Transformers from hold import Hold, InvoiceExistsError, NoSuchInvoiceError # TODO: restart handling # TODO: docstrings -# TODO: command to get private channels for pl = Plugin() hold = Hold(pl) @@ -50,11 +50,19 @@ def hold_invoice( description: str = "", expiry: int = Defaults.Expiry, min_final_cltv_expiry: int = Defaults.MinFinalCltvExpiry, + routing_hints: list[Any] | None = None, ) -> dict[str, Any]: try: return { "bolt11": hold.invoice( - payment_hash, amount_msat, description, expiry, min_final_cltv_expiry + payment_hash, + amount_msat, + description, + expiry, + min_final_cltv_expiry, + Transformers.routing_hints_from_json(routing_hints) + if routing_hints is not None + else None, ), } except InvoiceExistsError: @@ -92,6 +100,13 @@ def cancel_hold_invoice(plugin: Plugin, payment_hash: str) -> dict[str, Any]: return {} +@pl.method("routinghints") +def get_routing_hints(plugin: Plugin, node: str) -> dict[str, Any]: + return { + "hints": Transformers.named_tuples_to_dict(hold.get_private_channels(node)), + } + + @pl.method("dev-wipeholdinvoices") def wipe_hold_invoices(plugin: Plugin, payment_hash: str = "") -> dict[str, Any]: try: diff --git a/tools/hold/protos/hold.proto b/tools/hold/protos/hold.proto index 97a4f26b..ce514321 100644 --- a/tools/hold/protos/hold.proto +++ b/tools/hold/protos/hold.proto @@ -4,9 +4,11 @@ package hold; service Hold { rpc Invoice (InvoiceRequest) returns (InvoiceResponse) {} + rpc RoutingHints (RoutingHintsRequest) returns (RoutingHintsResponse) {} + rpc List (ListRequest) returns (ListResponse) {} + rpc Settle (SettleRequest) returns (SettleResponse) {} rpc Cancel (CancelRequest) returns (CancelResponse) {} - rpc List (ListRequest) returns (ListResponse) {} rpc Track (TrackRequest) returns (stream TrackResponse) {} rpc TrackAll (TrackAllRequest) returns (stream TrackAllResponse) {} @@ -18,20 +20,31 @@ message InvoiceRequest { optional string description = 3; optional uint64 expiry = 4; optional uint64 min_final_cltv_expiry = 5; + repeated RoutingHint routing_hints = 6; } message InvoiceResponse { string bolt11 = 1; } -message SettleRequest { - string payment_preimage = 1; +message RoutingHintsRequest { + string node = 1; } -message SettleResponse {} -message CancelRequest { - string payment_hash = 1; +message Hop { + string public_key = 1; + string short_channel_id = 2; + uint64 base_fee = 3; + uint64 ppm_fee = 4; + uint64 cltv_expiry_delta = 5; +} + +message RoutingHint { + repeated Hop hops = 1; +} + +message RoutingHintsResponse { + repeated RoutingHint hints = 1; } -message CancelResponse {} message ListRequest { optional string payment_hash = 1; @@ -55,10 +68,19 @@ message ListResponse { repeated Invoice invoices = 1; } -message TrackRequest { +message SettleRequest { + string payment_preimage = 1; +} +message SettleResponse {} + +message CancelRequest { string payment_hash = 1; } +message CancelResponse {} +message TrackRequest { + string payment_hash = 1; +} message TrackResponse { InvoiceState state = 1; } diff --git a/tools/hold/protos/hold_pb2.py b/tools/hold/protos/hold_pb2.py index 666a83c0..948d3cd8 100644 --- a/tools/hold/protos/hold_pb2.py +++ b/tools/hold/protos/hold_pb2.py @@ -13,7 +13,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\nhold.proto\x12\x04hold"\xc3\x01\n\x0eInvoiceRequest\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t\x12\x13\n\x0b\x61mount_msat\x18\x02 \x01(\x04\x12\x18\n\x0b\x64\x65scription\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06\x65xpiry\x18\x04 \x01(\x04H\x01\x88\x01\x01\x12"\n\x15min_final_cltv_expiry\x18\x05 \x01(\x04H\x02\x88\x01\x01\x42\x0e\n\x0c_descriptionB\t\n\x07_expiryB\x18\n\x16_min_final_cltv_expiry"!\n\x0fInvoiceResponse\x12\x0e\n\x06\x62olt11\x18\x01 \x01(\t")\n\rSettleRequest\x12\x18\n\x10payment_preimage\x18\x01 \x01(\t"\x10\n\x0eSettleResponse"%\n\rCancelRequest\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t"\x10\n\x0e\x43\x61ncelResponse"9\n\x0bListRequest\x12\x19\n\x0cpayment_hash\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x0f\n\r_payment_hash"\x86\x01\n\x07Invoice\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t\x12\x1d\n\x10payment_preimage\x18\x02 \x01(\tH\x00\x88\x01\x01\x12!\n\x05state\x18\x03 \x01(\x0e\x32\x12.hold.InvoiceState\x12\x0e\n\x06\x62olt11\x18\x04 \x01(\tB\x13\n\x11_payment_preimage"/\n\x0cListResponse\x12\x1f\n\x08invoices\x18\x01 \x03(\x0b\x32\r.hold.Invoice"$\n\x0cTrackRequest\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t"2\n\rTrackResponse\x12!\n\x05state\x18\x01 \x01(\x0e\x32\x12.hold.InvoiceState"\x11\n\x0fTrackAllRequest"K\n\x10TrackAllResponse\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t\x12!\n\x05state\x18\x02 \x01(\x0e\x32\x12.hold.InvoiceState*]\n\x0cInvoiceState\x12\x11\n\rInvoiceUnpaid\x10\x00\x12\x13\n\x0fInvoiceAccepted\x10\x01\x12\x0f\n\x0bInvoicePaid\x10\x02\x12\x14\n\x10InvoiceCancelled\x10\x03\x32\xd4\x02\n\x04Hold\x12\x38\n\x07Invoice\x12\x14.hold.InvoiceRequest\x1a\x15.hold.InvoiceResponse"\x00\x12\x35\n\x06Settle\x12\x13.hold.SettleRequest\x1a\x14.hold.SettleResponse"\x00\x12\x35\n\x06\x43\x61ncel\x12\x13.hold.CancelRequest\x1a\x14.hold.CancelResponse"\x00\x12/\n\x04List\x12\x11.hold.ListRequest\x1a\x12.hold.ListResponse"\x00\x12\x34\n\x05Track\x12\x12.hold.TrackRequest\x1a\x13.hold.TrackResponse"\x00\x30\x01\x12=\n\x08TrackAll\x12\x15.hold.TrackAllRequest\x1a\x16.hold.TrackAllResponse"\x00\x30\x01\x62\x06proto3' + b'\n\nhold.proto\x12\x04hold"\xed\x01\n\x0eInvoiceRequest\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t\x12\x13\n\x0b\x61mount_msat\x18\x02 \x01(\x04\x12\x18\n\x0b\x64\x65scription\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x13\n\x06\x65xpiry\x18\x04 \x01(\x04H\x01\x88\x01\x01\x12"\n\x15min_final_cltv_expiry\x18\x05 \x01(\x04H\x02\x88\x01\x01\x12(\n\rrouting_hints\x18\x06 \x03(\x0b\x32\x11.hold.RoutingHintB\x0e\n\x0c_descriptionB\t\n\x07_expiryB\x18\n\x16_min_final_cltv_expiry"!\n\x0fInvoiceResponse\x12\x0e\n\x06\x62olt11\x18\x01 \x01(\t"#\n\x13RoutingHintsRequest\x12\x0c\n\x04node\x18\x01 \x01(\t"q\n\x03Hop\x12\x12\n\npublic_key\x18\x01 \x01(\t\x12\x18\n\x10short_channel_id\x18\x02 \x01(\t\x12\x10\n\x08\x62\x61se_fee\x18\x03 \x01(\x04\x12\x0f\n\x07ppm_fee\x18\x04 \x01(\x04\x12\x19\n\x11\x63ltv_expiry_delta\x18\x05 \x01(\x04"&\n\x0bRoutingHint\x12\x17\n\x04hops\x18\x01 \x03(\x0b\x32\t.hold.Hop"8\n\x14RoutingHintsResponse\x12 \n\x05hints\x18\x01 \x03(\x0b\x32\x11.hold.RoutingHint"9\n\x0bListRequest\x12\x19\n\x0cpayment_hash\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x0f\n\r_payment_hash"\x86\x01\n\x07Invoice\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t\x12\x1d\n\x10payment_preimage\x18\x02 \x01(\tH\x00\x88\x01\x01\x12!\n\x05state\x18\x03 \x01(\x0e\x32\x12.hold.InvoiceState\x12\x0e\n\x06\x62olt11\x18\x04 \x01(\tB\x13\n\x11_payment_preimage"/\n\x0cListResponse\x12\x1f\n\x08invoices\x18\x01 \x03(\x0b\x32\r.hold.Invoice")\n\rSettleRequest\x12\x18\n\x10payment_preimage\x18\x01 \x01(\t"\x10\n\x0eSettleResponse"%\n\rCancelRequest\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t"\x10\n\x0e\x43\x61ncelResponse"$\n\x0cTrackRequest\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t"2\n\rTrackResponse\x12!\n\x05state\x18\x01 \x01(\x0e\x32\x12.hold.InvoiceState"\x11\n\x0fTrackAllRequest"K\n\x10TrackAllResponse\x12\x14\n\x0cpayment_hash\x18\x01 \x01(\t\x12!\n\x05state\x18\x02 \x01(\x0e\x32\x12.hold.InvoiceState*]\n\x0cInvoiceState\x12\x11\n\rInvoiceUnpaid\x10\x00\x12\x13\n\x0fInvoiceAccepted\x10\x01\x12\x0f\n\x0bInvoicePaid\x10\x02\x12\x14\n\x10InvoiceCancelled\x10\x03\x32\x9d\x03\n\x04Hold\x12\x38\n\x07Invoice\x12\x14.hold.InvoiceRequest\x1a\x15.hold.InvoiceResponse"\x00\x12G\n\x0cRoutingHints\x12\x19.hold.RoutingHintsRequest\x1a\x1a.hold.RoutingHintsResponse"\x00\x12/\n\x04List\x12\x11.hold.ListRequest\x1a\x12.hold.ListResponse"\x00\x12\x35\n\x06Settle\x12\x13.hold.SettleRequest\x1a\x14.hold.SettleResponse"\x00\x12\x35\n\x06\x43\x61ncel\x12\x13.hold.CancelRequest\x1a\x14.hold.CancelResponse"\x00\x12\x34\n\x05Track\x12\x12.hold.TrackRequest\x1a\x13.hold.TrackResponse"\x00\x30\x01\x12=\n\x08TrackAll\x12\x15.hold.TrackAllRequest\x1a\x16.hold.TrackAllResponse"\x00\x30\x01\x62\x06proto3' ) _globals = globals() @@ -21,34 +21,42 @@ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "hold_pb2", _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - _globals["_INVOICESTATE"]._serialized_start = 802 - _globals["_INVOICESTATE"]._serialized_end = 895 + _globals["_INVOICESTATE"]._serialized_start = 1094 + _globals["_INVOICESTATE"]._serialized_end = 1187 _globals["_INVOICEREQUEST"]._serialized_start = 21 - _globals["_INVOICEREQUEST"]._serialized_end = 216 - _globals["_INVOICERESPONSE"]._serialized_start = 218 - _globals["_INVOICERESPONSE"]._serialized_end = 251 - _globals["_SETTLEREQUEST"]._serialized_start = 253 - _globals["_SETTLEREQUEST"]._serialized_end = 294 - _globals["_SETTLERESPONSE"]._serialized_start = 296 - _globals["_SETTLERESPONSE"]._serialized_end = 312 - _globals["_CANCELREQUEST"]._serialized_start = 314 - _globals["_CANCELREQUEST"]._serialized_end = 351 - _globals["_CANCELRESPONSE"]._serialized_start = 353 - _globals["_CANCELRESPONSE"]._serialized_end = 369 - _globals["_LISTREQUEST"]._serialized_start = 371 - _globals["_LISTREQUEST"]._serialized_end = 428 - _globals["_INVOICE"]._serialized_start = 431 - _globals["_INVOICE"]._serialized_end = 565 - _globals["_LISTRESPONSE"]._serialized_start = 567 - _globals["_LISTRESPONSE"]._serialized_end = 614 - _globals["_TRACKREQUEST"]._serialized_start = 616 - _globals["_TRACKREQUEST"]._serialized_end = 652 - _globals["_TRACKRESPONSE"]._serialized_start = 654 - _globals["_TRACKRESPONSE"]._serialized_end = 704 - _globals["_TRACKALLREQUEST"]._serialized_start = 706 - _globals["_TRACKALLREQUEST"]._serialized_end = 723 - _globals["_TRACKALLRESPONSE"]._serialized_start = 725 - _globals["_TRACKALLRESPONSE"]._serialized_end = 800 - _globals["_HOLD"]._serialized_start = 898 - _globals["_HOLD"]._serialized_end = 1238 + _globals["_INVOICEREQUEST"]._serialized_end = 258 + _globals["_INVOICERESPONSE"]._serialized_start = 260 + _globals["_INVOICERESPONSE"]._serialized_end = 293 + _globals["_ROUTINGHINTSREQUEST"]._serialized_start = 295 + _globals["_ROUTINGHINTSREQUEST"]._serialized_end = 330 + _globals["_HOP"]._serialized_start = 332 + _globals["_HOP"]._serialized_end = 445 + _globals["_ROUTINGHINT"]._serialized_start = 447 + _globals["_ROUTINGHINT"]._serialized_end = 485 + _globals["_ROUTINGHINTSRESPONSE"]._serialized_start = 487 + _globals["_ROUTINGHINTSRESPONSE"]._serialized_end = 543 + _globals["_LISTREQUEST"]._serialized_start = 545 + _globals["_LISTREQUEST"]._serialized_end = 602 + _globals["_INVOICE"]._serialized_start = 605 + _globals["_INVOICE"]._serialized_end = 739 + _globals["_LISTRESPONSE"]._serialized_start = 741 + _globals["_LISTRESPONSE"]._serialized_end = 788 + _globals["_SETTLEREQUEST"]._serialized_start = 790 + _globals["_SETTLEREQUEST"]._serialized_end = 831 + _globals["_SETTLERESPONSE"]._serialized_start = 833 + _globals["_SETTLERESPONSE"]._serialized_end = 849 + _globals["_CANCELREQUEST"]._serialized_start = 851 + _globals["_CANCELREQUEST"]._serialized_end = 888 + _globals["_CANCELRESPONSE"]._serialized_start = 890 + _globals["_CANCELRESPONSE"]._serialized_end = 906 + _globals["_TRACKREQUEST"]._serialized_start = 908 + _globals["_TRACKREQUEST"]._serialized_end = 944 + _globals["_TRACKRESPONSE"]._serialized_start = 946 + _globals["_TRACKRESPONSE"]._serialized_end = 996 + _globals["_TRACKALLREQUEST"]._serialized_start = 998 + _globals["_TRACKALLREQUEST"]._serialized_end = 1015 + _globals["_TRACKALLRESPONSE"]._serialized_start = 1017 + _globals["_TRACKALLRESPONSE"]._serialized_end = 1092 + _globals["_HOLD"]._serialized_start = 1190 + _globals["_HOLD"]._serialized_end = 1603 # @@protoc_insertion_point(module_scope) diff --git a/tools/hold/protos/hold_pb2.pyi b/tools/hold/protos/hold_pb2.pyi index 64234bb4..bb25cfc1 100644 --- a/tools/hold/protos/hold_pb2.pyi +++ b/tools/hold/protos/hold_pb2.pyi @@ -31,17 +31,20 @@ class InvoiceRequest(_message.Message): "description", "expiry", "min_final_cltv_expiry", + "routing_hints", ] PAYMENT_HASH_FIELD_NUMBER: _ClassVar[int] AMOUNT_MSAT_FIELD_NUMBER: _ClassVar[int] DESCRIPTION_FIELD_NUMBER: _ClassVar[int] EXPIRY_FIELD_NUMBER: _ClassVar[int] MIN_FINAL_CLTV_EXPIRY_FIELD_NUMBER: _ClassVar[int] + ROUTING_HINTS_FIELD_NUMBER: _ClassVar[int] payment_hash: str amount_msat: int description: str expiry: int min_final_cltv_expiry: int + routing_hints: _containers.RepeatedCompositeFieldContainer[RoutingHint] def __init__( self, payment_hash: _Optional[str] = ..., @@ -49,6 +52,7 @@ class InvoiceRequest(_message.Message): description: _Optional[str] = ..., expiry: _Optional[int] = ..., min_final_cltv_expiry: _Optional[int] = ..., + routing_hints: _Optional[_Iterable[_Union[RoutingHint, _Mapping]]] = ..., ) -> None: ... class InvoiceResponse(_message.Message): @@ -57,25 +61,54 @@ class InvoiceResponse(_message.Message): bolt11: str def __init__(self, bolt11: _Optional[str] = ...) -> None: ... -class SettleRequest(_message.Message): - __slots__ = ["payment_preimage"] - PAYMENT_PREIMAGE_FIELD_NUMBER: _ClassVar[int] - payment_preimage: str - def __init__(self, payment_preimage: _Optional[str] = ...) -> None: ... +class RoutingHintsRequest(_message.Message): + __slots__ = ["node"] + NODE_FIELD_NUMBER: _ClassVar[int] + node: str + def __init__(self, node: _Optional[str] = ...) -> None: ... -class SettleResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... +class Hop(_message.Message): + __slots__ = [ + "public_key", + "short_channel_id", + "base_fee", + "ppm_fee", + "cltv_expiry_delta", + ] + PUBLIC_KEY_FIELD_NUMBER: _ClassVar[int] + SHORT_CHANNEL_ID_FIELD_NUMBER: _ClassVar[int] + BASE_FEE_FIELD_NUMBER: _ClassVar[int] + PPM_FEE_FIELD_NUMBER: _ClassVar[int] + CLTV_EXPIRY_DELTA_FIELD_NUMBER: _ClassVar[int] + public_key: str + short_channel_id: str + base_fee: int + ppm_fee: int + cltv_expiry_delta: int + def __init__( + self, + public_key: _Optional[str] = ..., + short_channel_id: _Optional[str] = ..., + base_fee: _Optional[int] = ..., + ppm_fee: _Optional[int] = ..., + cltv_expiry_delta: _Optional[int] = ..., + ) -> None: ... -class CancelRequest(_message.Message): - __slots__ = ["payment_hash"] - PAYMENT_HASH_FIELD_NUMBER: _ClassVar[int] - payment_hash: str - def __init__(self, payment_hash: _Optional[str] = ...) -> None: ... +class RoutingHint(_message.Message): + __slots__ = ["hops"] + HOPS_FIELD_NUMBER: _ClassVar[int] + hops: _containers.RepeatedCompositeFieldContainer[Hop] + def __init__( + self, hops: _Optional[_Iterable[_Union[Hop, _Mapping]]] = ... + ) -> None: ... -class CancelResponse(_message.Message): - __slots__ = [] - def __init__(self) -> None: ... +class RoutingHintsResponse(_message.Message): + __slots__ = ["hints"] + HINTS_FIELD_NUMBER: _ClassVar[int] + hints: _containers.RepeatedCompositeFieldContainer[RoutingHint] + def __init__( + self, hints: _Optional[_Iterable[_Union[RoutingHint, _Mapping]]] = ... + ) -> None: ... class ListRequest(_message.Message): __slots__ = ["payment_hash"] @@ -109,6 +142,26 @@ class ListResponse(_message.Message): self, invoices: _Optional[_Iterable[_Union[Invoice, _Mapping]]] = ... ) -> None: ... +class SettleRequest(_message.Message): + __slots__ = ["payment_preimage"] + PAYMENT_PREIMAGE_FIELD_NUMBER: _ClassVar[int] + payment_preimage: str + def __init__(self, payment_preimage: _Optional[str] = ...) -> None: ... + +class SettleResponse(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + +class CancelRequest(_message.Message): + __slots__ = ["payment_hash"] + PAYMENT_HASH_FIELD_NUMBER: _ClassVar[int] + payment_hash: str + def __init__(self, payment_hash: _Optional[str] = ...) -> None: ... + +class CancelResponse(_message.Message): + __slots__ = [] + def __init__(self) -> None: ... + class TrackRequest(_message.Message): __slots__ = ["payment_hash"] PAYMENT_HASH_FIELD_NUMBER: _ClassVar[int] diff --git a/tools/hold/protos/hold_pb2_grpc.py b/tools/hold/protos/hold_pb2_grpc.py index b7bf4b5d..ab8c6ab6 100644 --- a/tools/hold/protos/hold_pb2_grpc.py +++ b/tools/hold/protos/hold_pb2_grpc.py @@ -19,6 +19,16 @@ def __init__(self, channel): request_serializer=hold__pb2.InvoiceRequest.SerializeToString, response_deserializer=hold__pb2.InvoiceResponse.FromString, ) + self.RoutingHints = channel.unary_unary( + "/hold.Hold/RoutingHints", + request_serializer=hold__pb2.RoutingHintsRequest.SerializeToString, + response_deserializer=hold__pb2.RoutingHintsResponse.FromString, + ) + self.List = channel.unary_unary( + "/hold.Hold/List", + request_serializer=hold__pb2.ListRequest.SerializeToString, + response_deserializer=hold__pb2.ListResponse.FromString, + ) self.Settle = channel.unary_unary( "/hold.Hold/Settle", request_serializer=hold__pb2.SettleRequest.SerializeToString, @@ -29,11 +39,6 @@ def __init__(self, channel): request_serializer=hold__pb2.CancelRequest.SerializeToString, response_deserializer=hold__pb2.CancelResponse.FromString, ) - self.List = channel.unary_unary( - "/hold.Hold/List", - request_serializer=hold__pb2.ListRequest.SerializeToString, - response_deserializer=hold__pb2.ListResponse.FromString, - ) self.Track = channel.unary_stream( "/hold.Hold/Track", request_serializer=hold__pb2.TrackRequest.SerializeToString, @@ -55,19 +60,25 @@ def Invoice(self, request, context): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") - def Settle(self, request, context): + def RoutingHints(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") - def Cancel(self, request, context): + def List(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") - def List(self, request, context): + def Settle(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Cancel(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") @@ -93,6 +104,16 @@ def add_HoldServicer_to_server(servicer, server): request_deserializer=hold__pb2.InvoiceRequest.FromString, response_serializer=hold__pb2.InvoiceResponse.SerializeToString, ), + "RoutingHints": grpc.unary_unary_rpc_method_handler( + servicer.RoutingHints, + request_deserializer=hold__pb2.RoutingHintsRequest.FromString, + response_serializer=hold__pb2.RoutingHintsResponse.SerializeToString, + ), + "List": grpc.unary_unary_rpc_method_handler( + servicer.List, + request_deserializer=hold__pb2.ListRequest.FromString, + response_serializer=hold__pb2.ListResponse.SerializeToString, + ), "Settle": grpc.unary_unary_rpc_method_handler( servicer.Settle, request_deserializer=hold__pb2.SettleRequest.FromString, @@ -103,11 +124,6 @@ def add_HoldServicer_to_server(servicer, server): request_deserializer=hold__pb2.CancelRequest.FromString, response_serializer=hold__pb2.CancelResponse.SerializeToString, ), - "List": grpc.unary_unary_rpc_method_handler( - servicer.List, - request_deserializer=hold__pb2.ListRequest.FromString, - response_serializer=hold__pb2.ListResponse.SerializeToString, - ), "Track": grpc.unary_stream_rpc_method_handler( servicer.Track, request_deserializer=hold__pb2.TrackRequest.FromString, @@ -159,7 +175,7 @@ def Invoice( ) @staticmethod - def Settle( + def RoutingHints( request, target, options=(), @@ -174,9 +190,9 @@ def Settle( return grpc.experimental.unary_unary( request, target, - "/hold.Hold/Settle", - hold__pb2.SettleRequest.SerializeToString, - hold__pb2.SettleResponse.FromString, + "/hold.Hold/RoutingHints", + hold__pb2.RoutingHintsRequest.SerializeToString, + hold__pb2.RoutingHintsResponse.FromString, options, channel_credentials, insecure, @@ -188,7 +204,7 @@ def Settle( ) @staticmethod - def Cancel( + def List( request, target, options=(), @@ -203,9 +219,9 @@ def Cancel( return grpc.experimental.unary_unary( request, target, - "/hold.Hold/Cancel", - hold__pb2.CancelRequest.SerializeToString, - hold__pb2.CancelResponse.FromString, + "/hold.Hold/List", + hold__pb2.ListRequest.SerializeToString, + hold__pb2.ListResponse.FromString, options, channel_credentials, insecure, @@ -217,7 +233,7 @@ def Cancel( ) @staticmethod - def List( + def Settle( request, target, options=(), @@ -232,9 +248,38 @@ def List( return grpc.experimental.unary_unary( request, target, - "/hold.Hold/List", - hold__pb2.ListRequest.SerializeToString, - hold__pb2.ListResponse.FromString, + "/hold.Hold/Settle", + hold__pb2.SettleRequest.SerializeToString, + hold__pb2.SettleResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Cancel( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/hold.Hold/Cancel", + hold__pb2.CancelRequest.SerializeToString, + hold__pb2.CancelResponse.FromString, options, channel_credentials, insecure, diff --git a/tools/hold/route_hints.py b/tools/hold/route_hints.py new file mode 100644 index 00000000..5318c6c0 --- /dev/null +++ b/tools/hold/route_hints.py @@ -0,0 +1,31 @@ +from bolt11.models.routehint import Route, RouteHint +from pyln.client import Plugin + + +class RouteHints: + _id: str + + def __init__(self, plugin: Plugin) -> None: + self._plugin = plugin + + def init(self) -> None: + self._id = self._plugin.rpc.getinfo()["id"] + + def get_private_channels(self, node: str) -> list[RouteHint]: + chans = self._plugin.rpc.listchannels(destination=self._id)["channels"] + return [ + RouteHint( + [ + Route( + public_key=chan["source"], + short_channel_id=chan["short_channel_id"], + base_fee=chan["base_fee_millisatoshi"], + ppm_fee=chan["fee_per_millionth"], + cltv_expiry_delta=chan["delay"], + ) + ] + ) + for chan in filter( + lambda chan: not chan["public"] and chan["source"] == node, chans + ) + ] diff --git a/tools/hold/server.py b/tools/hold/server.py index 64a15aed..09100285 100644 --- a/tools/hold/server.py +++ b/tools/hold/server.py @@ -1,14 +1,12 @@ import threading +from collections.abc import Callable, Iterable from concurrent import futures from queue import Empty -from typing import Callable, Iterable, TypeVar +from typing import TypeVar import grpc from encoder import Defaults from enums import invoice_state_final - -# noinspection PyProtectedMember -from grpc._server import _Server from grpc_interceptor import ServerInterceptor from protos.hold_pb2 import ( CancelRequest, @@ -17,6 +15,8 @@ InvoiceResponse, ListRequest, ListResponse, + RoutingHintsRequest, + RoutingHintsResponse, SettleRequest, SettleResponse, TrackAllRequest, @@ -64,9 +64,29 @@ def Invoice( # noqa: N802 optional_default( request.min_final_cltv_expiry, 0, Defaults.MinFinalCltvExpiry ), + Transformers.routing_hints_from_grpc(list(request.routing_hints)), ) ) + def RoutingHints( # noqa: N802 + self, + request: RoutingHintsRequest, + context: grpc.ServicerContext, # noqa: ARG002 + ) -> RoutingHintsResponse: + return Transformers.routing_hints_to_grpc( + self._hold.get_private_channels(request.node) + ) + + def List( # noqa: N802 + self, request: ListRequest, context: grpc.ServicerContext # noqa: ARG002 + ) -> ListResponse: + return ListResponse( + invoices=[ + Transformers.invoice_to_grpc(inv) + for inv in self._hold.list_invoices(request.payment_hash) + ] + ) + def Settle( # noqa: N802 self, request: SettleRequest, context: grpc.ServicerContext # noqa: ARG002 ) -> SettleResponse: @@ -79,16 +99,6 @@ def Cancel( # noqa: N802 self._hold.cancel(request.payment_hash) return CancelResponse() - def List( # noqa: N802 - self, request: ListRequest, context: grpc.ServicerContext # noqa: ARG002 - ) -> ListResponse: - return ListResponse( - invoices=[ - Transformers.invoice_to_grpc(inv) - for inv in self._hold.list_invoices(request.payment_hash) - ] - ) - def Track( # noqa: N802 self, request: TrackRequest, context: grpc.ServicerContext ) -> Iterable[TrackResponse]: @@ -162,7 +172,7 @@ class Server: _hold: Hold _plugin: Plugin - _server: _Server | None + _server: grpc.Server | None _server_thread: threading.Thread diff --git a/tools/hold/test_encoder.py b/tools/hold/test_encoder.py index b9a51aa1..cecbab5e 100644 --- a/tools/hold/test_encoder.py +++ b/tools/hold/test_encoder.py @@ -1,19 +1,22 @@ import random import pytest +from bolt11.models.routehint import Route, RouteHint from encoder import Defaults, Encoder, get_network_prefix, get_payment_secret -from test_utils import cln_con +from test_utils import RpcPlugin, cln_con from utils import time_now - -class RpcCaller: - @staticmethod - def getinfo() -> dict: - return cln_con("getinfo") - - -class RpcPlugin: - rpc = RpcCaller() +route_hint = RouteHint( + routes=[ + Route( + public_key="02e425026d928083eb432886c4c209abff4aea1e6bafca208671fdb0e42be4b63d", + short_channel_id="117x1x0", + base_fee=1000, + ppm_fee=1, + cltv_expiry_delta=80, + ) + ] +) class TestEncoder: @@ -88,6 +91,7 @@ def test_encode_min_final_cltv_expiry(self, cltv: int) -> None: ) dec = cln_con("decode", invoice) + assert dec["valid"] assert dec["min_final_cltv_expiry"] == cltv def test_encode_payment_secret(self) -> None: @@ -101,4 +105,44 @@ def test_encode_payment_secret(self) -> None: ) dec = cln_con("decode", invoice) + + assert dec["valid"] assert dec["payment_secret"] == payment_secret + + def test_encode_route_hints(self) -> None: + invoice = self.en.encode( + random.randbytes(32).hex(), + 10_000, + "memo", + route_hints=[route_hint], + ) + + dec = cln_con("decode", invoice) + assert dec["valid"] + + routes = dec["routes"] + assert len(routes) == 1 + + route = routes[0] + assert len(route) == 1 + + hop = route[0] + assert hop["pubkey"] == route_hint.routes[0].public_key + assert hop["fee_base_msat"] == route_hint.routes[0].base_fee + assert hop["short_channel_id"] == route_hint.routes[0].short_channel_id + assert hop["fee_proportional_millionths"] == route_hint.routes[0].ppm_fee + + def test_encode_route_hints_multiple(self) -> None: + invoice = self.en.encode( + random.randbytes(32).hex(), + 10_000, + "memo", + route_hints=[route_hint, route_hint], + ) + + dec = cln_con("decode", invoice) + assert dec["valid"] + + routes = dec["routes"] + assert len(routes) == 2 + assert routes[0] == routes[1] diff --git a/tools/hold/test_grpc.py b/tools/hold/test_grpc.py index 970acf4a..6898c8ad 100644 --- a/tools/hold/test_grpc.py +++ b/tools/hold/test_grpc.py @@ -19,6 +19,8 @@ InvoiceState, InvoiceUnpaid, ListRequest, + RoutingHintsRequest, + RoutingHintsResponse, SettleRequest, TrackAllRequest, TrackRequest, @@ -29,6 +31,8 @@ LndPay, cln_con, connect_peers, + get_channel_info, + lnd, start_plugin, stop_plugin, ) @@ -92,7 +96,7 @@ def test_invoice_defaults(self, cl: HoldStub) -> None: assert dec["min_final_cltv_expiry"] == Defaults.MinFinalCltvExpiry @pytest.mark.parametrize("description", ["some", "text", "Send to BTC address"]) - def test_add_description(self, cl: HoldStub, description: str) -> None: + def test_invoice_description(self, cl: HoldStub, description: str) -> None: invoice = cl.Invoice( InvoiceRequest( payment_hash=random.randbytes(32).hex(), @@ -106,7 +110,7 @@ def test_add_description(self, cl: HoldStub, description: str) -> None: assert dec["description"] == description @pytest.mark.parametrize("expiry", [1, 2, 3, 3600, 24000, 86400]) - def test_add_expiry(self, cl: HoldStub, expiry: int) -> None: + def test_invoice_expiry(self, cl: HoldStub, expiry: int) -> None: invoice = cl.Invoice( InvoiceRequest( payment_hash=random.randbytes(32).hex(), @@ -120,7 +124,7 @@ def test_add_expiry(self, cl: HoldStub, expiry: int) -> None: assert dec["expiry"] == expiry @pytest.mark.parametrize("min_final_cltv_expiry", [1, 2, 3, 80, 144, 150, 200]) - def test_add_min_final_cltv_expiry( + def test_invoice_min_final_cltv_expiry( self, cl: HoldStub, min_final_cltv_expiry: int, @@ -137,6 +141,83 @@ def test_add_min_final_cltv_expiry( assert dec["valid"] assert dec["min_final_cltv_expiry"] == min_final_cltv_expiry + def test_invoice_routing_hints(self, cl: HoldStub) -> None: + lnd_pubkey = lnd(LndNode.One, "getinfo")["identity_pubkey"] + routing_hints: RoutingHintsResponse = cl.RoutingHints( + RoutingHintsRequest(node=lnd_pubkey) + ) + + invoice = cl.Invoice( + InvoiceRequest( + payment_hash=random.randbytes(32).hex(), + amount_msat=10_000, + routing_hints=routing_hints.hints, + ) + ).bolt11 + + dec = cln_con("decode", invoice) + assert dec["valid"] + assert len(dec["routes"]) == 1 + assert len(dec["routes"][0]) == 1 + + hop = dec["routes"][0][0] + hint = routing_hints.hints[0].hops[0] + + assert hop["pubkey"] == hint.public_key + assert hop["fee_base_msat"] == hint.base_fee + assert hop["short_channel_id"] == hint.short_channel_id + assert hop["fee_proportional_millionths"] == hint.ppm_fee + + def test_invoice_routing_hints_multiple(self, cl: HoldStub) -> None: + lnd_pubkey = lnd(LndNode.One, "getinfo")["identity_pubkey"] + routing_hints: RoutingHintsResponse = cl.RoutingHints( + RoutingHintsRequest(node=lnd_pubkey) + ) + + routing_hints.hints.append(routing_hints.hints[0]) + + invoice = cl.Invoice( + InvoiceRequest( + payment_hash=random.randbytes(32).hex(), + amount_msat=10_000, + routing_hints=routing_hints.hints, + ) + ).bolt11 + + dec = cln_con("decode", invoice) + assert dec["valid"] + assert len(dec["routes"]) == 2 + assert len(dec["routes"][0]) == 1 + assert len(dec["routes"][1]) == 1 + + assert dec["routes"][0] == dec["routes"][1] + + def test_routing_hints(self, cl: HoldStub) -> None: + lnd_pubkey = lnd(LndNode.One, "getinfo")["identity_pubkey"] + + res: RoutingHintsResponse = cl.RoutingHints( + RoutingHintsRequest(node=lnd_pubkey) + ) + assert len(res.hints) == 1 + + hops = res.hints[0].hops + assert len(hops) == 1 + + hop = hops[0] + + channel_info = get_channel_info(lnd_pubkey, hop.short_channel_id) + + assert hop.cltv_expiry_delta == channel_info["delay"] + assert hop.ppm_fee == channel_info["fee_per_millionth"] + assert hop.base_fee == channel_info["base_fee_millisatoshi"] + assert hop.short_channel_id == channel_info["short_channel_id"] + + def test_routing_hints_none_found(self, cl: HoldStub) -> None: + res: RoutingHintsResponse = cl.RoutingHints( + RoutingHintsRequest(node="not found") + ) + assert len(res.hints) == 0 + def test_list(self, cl: HoldStub) -> None: invoices = cl.List(ListRequest()).invoices assert len(invoices) > 1 diff --git a/tools/hold/test_plugin.py b/tools/hold/test_plugin.py index 1a1df934..6335e256 100644 --- a/tools/hold/test_plugin.py +++ b/tools/hold/test_plugin.py @@ -15,6 +15,7 @@ cln_con, connect_peers, format_json, + get_channel_info, lnd, start_plugin, stop_plugin, @@ -120,6 +121,52 @@ def test_add_min_final_cltv_expiry( assert dec["valid"] assert dec["min_final_cltv_expiry"] == min_final_cltv_expiry + def test_invoice_routing_hints(self, cln: CliCaller) -> None: + lnd_pubkey = lnd(LndNode.One, "getinfo")["identity_pubkey"] + hints = cln("routinghints", lnd_pubkey)["hints"] + + invoice = cln( + "-k", + "holdinvoice", + f"payment_hash={random.randbytes(32).hex()}", + "amount_msat=10000", + f"routing_hints={format_json(hints)}", + )["bolt11"] + + dec = cln_con("decode", invoice) + assert dec["valid"] + assert len(dec["routes"]) == 1 + assert len(dec["routes"][0]) == 1 + + hop = dec["routes"][0][0] + hint = hints[0]["routes"][0] + + assert hop["pubkey"] == hint["public_key"] + assert hop["fee_base_msat"] == hint["base_fee"] + assert hop["short_channel_id"] == hint["short_channel_id"] + assert hop["fee_proportional_millionths"] == hint["ppm_fee"] + + def test_invoice_routing_hints_multiple(self, cln: CliCaller) -> None: + lnd_pubkey = lnd(LndNode.One, "getinfo")["identity_pubkey"] + hints = cln("routinghints", lnd_pubkey)["hints"] + hints.append(hints[0]) + + invoice = cln( + "-k", + "holdinvoice", + f"payment_hash={random.randbytes(32).hex()}", + "amount_msat=10000", + f"routing_hints={format_json(hints)}", + )["bolt11"] + + dec = cln_con("decode", invoice) + assert dec["valid"] + assert len(dec["routes"]) == 2 + assert len(dec["routes"][0]) == 1 + assert len(dec["routes"][1]) == 1 + + assert dec["routes"][0] == dec["routes"][1] + def test_add_duplicate_fail(self, cln: CliCaller) -> None: amount = 10000 payment_hash = random.randbytes(32).hex() @@ -457,6 +504,29 @@ def test_htlc_payment_secret_missing(self, cln: CliCaller) -> None: == "unpaid" ) + def test_routinghints(self, cln: CliCaller) -> None: + lnd_pubkey = lnd(LndNode.One, "getinfo")["identity_pubkey"] + hints = cln("routinghints", lnd_pubkey)["hints"] + assert len(hints) == 1 + + routes = hints[0] + assert len(routes) == 1 + assert len(routes["routes"]) == 1 + + route = routes["routes"][0] + + channel_info = get_channel_info(lnd_pubkey, route["short_channel_id"]) + + assert route["cltv_expiry_delta"] == channel_info["delay"] + assert route["ppm_fee"] == channel_info["fee_per_millionth"] + assert route["base_fee"] == channel_info["base_fee_millisatoshi"] + assert route["short_channel_id"] == channel_info["short_channel_id"] + + def test_routinghints_none_found(self, cln: CliCaller) -> None: + res = cln("routinghints", "none") + assert "hints" in res + assert len(res["hints"]) == 0 + def test_wipe_single(self, cln: CliCaller) -> None: _, payment_hash, _ = add_hold_invoice(cln) res = cln("dev-wipeholdinvoices", payment_hash) @@ -491,14 +561,14 @@ def test_ignore_non_hold(self, cln: CliCaller) -> None: def test_ignore_forward(self, cln: CliCaller) -> None: cln_id = cln("getinfo")["id"] - channels = lnd(LndNode.Two, "listchannels")["channels"] + channels = lnd(LndNode.One, "listchannels")["channels"] cln_channel = next(c for c in channels if c["remote_pubkey"] == cln_id)[ "chan_id" ] - invoice = lnd(LndNode.One, "addinvoice", "10000")["payment_request"] + invoice = lnd(LndNode.Two, "addinvoice", "10000")["payment_request"] - pay = LndPay(LndNode.Two, invoice, outgoing_chan_id=cln_channel) + pay = LndPay(LndNode.One, invoice, outgoing_chan_id=cln_channel) pay.start() pay.join() diff --git a/tools/hold/test_route_hints.py b/tools/hold/test_route_hints.py new file mode 100644 index 00000000..8a18a394 --- /dev/null +++ b/tools/hold/test_route_hints.py @@ -0,0 +1,35 @@ +from route_hints import RouteHints +from test_utils import LndNode, RpcPlugin, cln_con, get_channel_info, lnd + + +class TestRouteHints: + # noinspection PyTypeChecker + rh = RouteHints(RpcPlugin()) + + def test_init(self) -> None: + self.rh.init() + + assert self.rh._plugin is not None # noqa: SLF001 + assert self.rh._id == cln_con("getinfo")["id"] # noqa: SLF001 + + def test_get_private_channels(self) -> None: + other_pubkey = lnd(LndNode.One, "getinfo")["identity_pubkey"] + + hints = self.rh.get_private_channels(other_pubkey) + assert len(hints) == 1 + + hint = hints[0] + assert len(hint.routes) == 1 + + route = hint.routes[0] + assert route.public_key == other_pubkey + + channel_info = get_channel_info(other_pubkey, route.short_channel_id) + + assert route.cltv_expiry_delta == channel_info["delay"] + assert route.ppm_fee == channel_info["fee_per_millionth"] + assert route.base_fee == channel_info["base_fee_millisatoshi"] + assert route.short_channel_id == channel_info["short_channel_id"] + + def test_get_private_channels_none_found(self) -> None: + assert self.rh.get_private_channels("not found") == [] diff --git a/tools/hold/test_utils.py b/tools/hold/test_utils.py index a4560119..878b8a8a 100644 --- a/tools/hold/test_utils.py +++ b/tools/hold/test_utils.py @@ -5,11 +5,29 @@ from threading import Thread from typing import Any -PLUGIN_PATH = "/root/hold-start.sh" +PLUGIN_PATH = "/root/hold.sh" CliCaller = Callable[..., dict[str, Any]] +class RpcCaller: + @staticmethod + def getinfo() -> dict: + return cln_con("getinfo") + + @staticmethod + def listchannels(**kwargs: dict[str, str]) -> dict: + args = "listchannels -k" + for key, val in kwargs.items(): + args += f" {key}={val}" + + return cln_con(args) + + +class RpcPlugin: + rpc = RpcCaller() + + class LndNode(Enum): One = 1 Two = 2 @@ -101,3 +119,10 @@ def cln_con(*args: str) -> dict[str, Any]: f"docker exec regtest lightning-cli {' '.join(args)}", ) ) + + +def get_channel_info(node: str, short_chan_id: str | int) -> dict[str, Any]: + channel_infos = cln_con("listchannels", "-k", f"short_channel_id={short_chan_id}")[ + "channels" + ] + return channel_infos[0] if channel_infos[0]["source"] == node else channel_infos[1] diff --git a/tools/hold/transformers.py b/tools/hold/transformers.py index 8bff244c..89f782dd 100644 --- a/tools/hold/transformers.py +++ b/tools/hold/transformers.py @@ -1,10 +1,16 @@ +from typing import Any + +from bolt11.models.routehint import Route, RouteHint from invoice import HoldInvoice, InvoiceState from protos.hold_pb2 import ( + Hop, Invoice, InvoiceAccepted, InvoiceCancelled, InvoicePaid, InvoiceUnpaid, + RoutingHint, + RoutingHintsResponse, ) INVOICE_STATE_TO_GRPC = { @@ -24,3 +30,78 @@ def invoice_to_grpc(invoice: HoldInvoice) -> Invoice: state=INVOICE_STATE_TO_GRPC[invoice.state], bolt11=invoice.bolt11, ) + + @staticmethod + def routing_hints_to_grpc(hints: list[RouteHint]) -> RoutingHintsResponse: + return RoutingHintsResponse( + hints=[Transformers.routing_hint_to_grpc(hint) for hint in hints] + ) + + @staticmethod + def routing_hint_to_grpc(hint: RouteHint) -> RoutingHint: + return RoutingHint(hops=[Transformers.hop_to_grpc(hop) for hop in hint.routes]) + + @staticmethod + def hop_to_grpc(hop: Route) -> Hop: + return Hop( + public_key=hop.public_key, + short_channel_id=hop.short_channel_id, + base_fee=hop.base_fee, + ppm_fee=hop.ppm_fee, + cltv_expiry_delta=hop.cltv_expiry_delta, + ) + + @staticmethod + def routing_hints_from_grpc(hints: list[RoutingHint]) -> list[RouteHint]: + return [Transformers.routing_hint_from_grpc(hint) for hint in hints] + + @staticmethod + def routing_hint_from_grpc(hint: RoutingHint) -> RouteHint: + return RouteHint(routes=[Transformers.hop_from_grpc(hop) for hop in hint.hops]) + + @staticmethod + def hop_from_grpc(hop: Hop) -> Route: + return Route( + public_key=hop.public_key, + short_channel_id=hop.short_channel_id, + base_fee=hop.base_fee, + ppm_fee=hop.ppm_fee, + cltv_expiry_delta=hop.cltv_expiry_delta, + ) + + @staticmethod + def routing_hints_from_json(hints: list[Any]) -> list[RouteHint]: + return [Transformers.routing_hint_from_json(hint) for hint in hints] + + @staticmethod + def routing_hint_from_json(hint: dict[str, Any]) -> RouteHint: + return RouteHint( + routes=[Transformers.hop_from_json(hop) for hop in hint["routes"]] + ) + + @staticmethod + def hop_from_json(hop: dict[str, str | int]) -> Route: + return Route( + public_key=hop["public_key"], + short_channel_id=hop["short_channel_id"], + base_fee=hop["base_fee"], + ppm_fee=hop["ppm_fee"], + cltv_expiry_delta=hop["cltv_expiry_delta"], + ) + + @staticmethod + def named_tuples_to_dict(val: object) -> object: + if isinstance(val, list): + return [Transformers.named_tuples_to_dict(entry) for entry in val] + + if isinstance(val, tuple) and hasattr(val, "_asdict"): + # noinspection PyProtectedMember + return Transformers.named_tuples_to_dict(val._asdict()) + + if isinstance(val, dict): + return { + key: Transformers.named_tuples_to_dict(value) + for key, value in val.items() + } + + return val