From 8f6ba65fd31a369873d2a9cb42ca30be3c514515 Mon Sep 17 00:00:00 2001 From: Camillarhi Date: Tue, 6 Jan 2026 00:28:21 +0100 Subject: [PATCH] receive payjoin payments --- Cargo.toml | 2 + src/chain/bitcoind.rs | 91 +++ src/chain/electrum.rs | 128 +++- src/chain/esplora.rs | 7 + src/chain/mod.rs | 34 +- src/config.rs | 12 +- src/error.rs | 3 + src/io/mod.rs | 4 + src/payment/mod.rs | 1 + src/payment/payjoin_payment/manager.rs | 602 ++++++++++++++++++ src/payment/payjoin_payment/mod.rs | 10 + .../payjoin_payment/payjoin_session.rs | 249 ++++++++ src/payment/payjoin_payment/persist.rs | 227 +++++++ src/types.rs | 3 + src/wallet/mod.rs | 92 +++ 15 files changed, 1462 insertions(+), 3 deletions(-) create mode 100644 src/payment/payjoin_payment/manager.rs create mode 100644 src/payment/payjoin_payment/mod.rs create mode 100644 src/payment/payjoin_payment/payjoin_session.rs create mode 100644 src/payment/payjoin_payment/persist.rs diff --git a/Cargo.toml b/Cargo.toml index f99b164a2..a344c83ce 100755 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,8 @@ prost = { version = "0.11.6", default-features = false} #bitcoin-payment-instructions = { version = "0.6" } bitcoin-payment-instructions = { git = "https://github.com/tnull/bitcoin-payment-instructions", rev = "6796e87525d6c564e1332354a808730e2ba2ebf8" } +payjoin = { git = "https://github.com/payjoin/rust-payjoin.git", package = "payjoin", default-features = false, features = ["v2", "io"] } + [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["winbase"] } diff --git a/src/chain/bitcoind.rs b/src/chain/bitcoind.rs index 1c8cf16ba..d40a9a54e 100644 --- a/src/chain/bitcoind.rs +++ b/src/chain/bitcoind.rs @@ -618,6 +618,57 @@ impl BitcoindChainSource { } } } + + pub(crate) async fn can_broadcast_transaction(&self, tx: &Transaction) -> Result { + let timeout_fut = tokio::time::timeout( + Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), + self.api_client.test_mempool_accept(tx), + ); + + match timeout_fut.await { + Ok(res) => res.map_err(|e| { + log_error!( + self.logger, + "Failed to test mempool accept for transaction {}: {}", + tx.compute_txid(), + e + ); + Error::TxBroadcastFailed + }), + Err(e) => { + log_error!( + self.logger, + "Failed to test mempool accept for transaction {} due to timeout: {}", + tx.compute_txid(), + e + ); + log_trace!( + self.logger, + "Failed test mempool accept transaction bytes: {}", + log_bytes!(tx.encode()) + ); + Err(Error::TxBroadcastFailed) + }, + } + } + + pub(crate) async fn get_transaction(&self, txid: &Txid) -> Result, Error> { + let timeout_fut = tokio::time::timeout( + Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), + self.api_client.get_raw_transaction(txid), + ); + + match timeout_fut.await { + Ok(res) => res.map_err(|e| { + log_error!(self.logger, "Failed to get transaction {}: {}", txid, e); + Error::TxSyncFailed + }), + Err(e) => { + log_error!(self.logger, "Failed to get transaction {} due to timeout: {}", txid, e); + Err(Error::TxSyncTimeout) + }, + } + } } #[derive(Clone)] @@ -1235,6 +1286,46 @@ impl BitcoindClient { .collect(); Ok(evicted_txids) } + + /// Tests whether the provided transaction would be accepted by the mempool. + pub(crate) async fn test_mempool_accept(&self, tx: &Transaction) -> std::io::Result { + match self { + BitcoindClient::Rpc { rpc_client, .. } => { + Self::test_mempool_accept_inner(Arc::clone(rpc_client), tx).await + }, + BitcoindClient::Rest { rpc_client, .. } => { + // We rely on the internal RPC client to make this call, as this + // operation is not supported by Bitcoin Core's REST interface. + Self::test_mempool_accept_inner(Arc::clone(rpc_client), tx).await + }, + } + } + + async fn test_mempool_accept_inner( + rpc_client: Arc, tx: &Transaction, + ) -> std::io::Result { + let tx_serialized = bitcoin::consensus::encode::serialize_hex(tx); + let tx_array = serde_json::json!([tx_serialized]); + + let resp = + rpc_client.call_method::("testmempoolaccept", &[tx_array]).await?; + + if let Some(array) = resp.as_array() { + if let Some(first_result) = array.first() { + Ok(first_result.get("allowed").and_then(|v| v.as_bool()).unwrap_or(false)) + } else { + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Empty array response from testmempoolaccept", + )) + } + } else { + Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "testmempoolaccept did not return an array", + )) + } + } } impl BlockSource for BitcoindClient { diff --git a/src/chain/electrum.rs b/src/chain/electrum.rs index 21e66f3a6..8bf539149 100644 --- a/src/chain/electrum.rs +++ b/src/chain/electrum.rs @@ -15,7 +15,7 @@ use bdk_chain::bdk_core::spk_client::{ }; use bdk_electrum::BdkElectrumClient; use bdk_wallet::{KeychainKind as BdkKeyChainKind, Update as BdkUpdate}; -use bitcoin::{FeeRate, Network, Script, ScriptBuf, Transaction, Txid}; +use bitcoin::{FeeRate, Network, OutPoint, Script, ScriptBuf, Transaction, Txid}; use electrum_client::{ Batch, Client as ElectrumClient, ConfigBuilder as ElectrumConfigBuilder, ElectrumApi, }; @@ -291,6 +291,21 @@ impl ElectrumChainSource { electrum_client.broadcast(tx).await; } } + + pub(crate) async fn get_transaction(&self, txid: &Txid) -> Result, Error> { + let electrum_client: Arc = + if let Some(client) = self.electrum_runtime_status.read().unwrap().client().as_ref() { + Arc::clone(client) + } else { + debug_assert!( + false, + "We should have started the chain source before getting transactions" + ); + return Err(Error::TxSyncFailed); + }; + + electrum_client.get_transaction(txid).await + } } impl Filter for ElectrumChainSource { @@ -632,6 +647,117 @@ impl ElectrumRuntimeClient { Ok(new_fee_rate_cache) } + + async fn get_transaction(&self, txid: &Txid) -> Result, Error> { + let electrum_client = Arc::clone(&self.electrum_client); + let txid_copy = *txid; + + let spawn_fut = + self.runtime.spawn_blocking(move || electrum_client.transaction_get(&txid_copy)); + let timeout_fut = + tokio::time::timeout(Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), spawn_fut); + + match timeout_fut.await { + Ok(res) => match res { + Ok(inner_res) => match inner_res { + Ok(tx) => Ok(Some(tx)), + Err(e) => { + // Check if it's a "not found" error + let error_str = e.to_string(); + if error_str.contains("No such mempool or blockchain transaction") + || error_str.contains("not found") + { + Ok(None) + } else { + log_error!(self.logger, "Failed to get transaction {}: {}", txid, e); + Err(Error::TxSyncFailed) + } + }, + }, + Err(e) => { + log_error!(self.logger, "Failed to get transaction {}: {}", txid, e); + Err(Error::TxSyncFailed) + }, + }, + Err(e) => { + log_error!(self.logger, "Failed to get transaction {} due to timeout: {}", txid, e); + Err(Error::TxSyncTimeout) + }, + } + } + + async fn is_outpoint_spent(&self, outpoint: &OutPoint) -> Result { + // First get the transaction to find the scriptPubKey of the output + let tx = match self.get_transaction(&outpoint.txid).await? { + Some(tx) => tx, + None => { + // Transaction doesn't exist, so outpoint can't be spent + // (or never existed) + return Ok(false); + }, + }; + + // Check if the output index is valid + let vout = outpoint.vout as usize; + if vout >= tx.output.len() { + // Invalid output index + return Ok(false); + } + + let script_pubkey = &tx.output[vout].script_pubkey; + let electrum_client = Arc::clone(&self.electrum_client); + let script_pubkey_clone = script_pubkey.clone(); + let outpoint_txid = outpoint.txid; + let outpoint_vout = outpoint.vout; + + let spawn_fut = self + .runtime + .spawn_blocking(move || electrum_client.script_list_unspent(&script_pubkey_clone)); + let timeout_fut = + tokio::time::timeout(Duration::from_secs(TX_BROADCAST_TIMEOUT_SECS), spawn_fut); + + match timeout_fut.await { + Ok(res) => match res { + Ok(inner_res) => match inner_res { + Ok(unspent_list) => { + // Check if our outpoint is in the unspent list + let is_unspent = unspent_list.iter().any(|u| { + u.tx_hash == outpoint_txid && u.tx_pos == outpoint_vout as usize + }); + // Return true if spent (not in unspent list) + Ok(!is_unspent) + }, + Err(e) => { + log_error!( + self.logger, + "Failed to check if outpoint {} is spent: {}", + outpoint, + e + ); + Err(Error::TxSyncFailed) + }, + }, + Err(e) => { + log_error!( + self.logger, + "Failed to check if outpoint {} is spent: {}", + outpoint, + e + ); + Err(Error::TxSyncFailed) + }, + }, + Err(e) => { + log_error!( + self.logger, + "Failed to check if outpoint {} is spent due to timeout: {}", + outpoint, + e + ); + Err(Error::TxSyncTimeout) + }, + } + } } impl Filter for ElectrumRuntimeClient { diff --git a/src/chain/esplora.rs b/src/chain/esplora.rs index 8ab941888..ddf679610 100644 --- a/src/chain/esplora.rs +++ b/src/chain/esplora.rs @@ -417,6 +417,13 @@ impl EsploraChainSource { } } } + + pub(crate) async fn get_transaction(&self, txid: &Txid) -> Result, Error> { + self.esplora_client.get_tx(txid).await.map_err(|e| { + log_error!(self.logger, "Failed to get transaction {}: {}", txid, e); + Error::TxSyncFailed + }) + } } impl Filter for EsploraChainSource { diff --git a/src/chain/mod.rs b/src/chain/mod.rs index afd502363..dccab0e05 100644 --- a/src/chain/mod.rs +++ b/src/chain/mod.rs @@ -13,7 +13,7 @@ use std::collections::HashMap; use std::sync::{Arc, RwLock}; use std::time::Duration; -use bitcoin::{Script, Txid}; +use bitcoin::{Script, Transaction, Txid}; use lightning::chain::{BestBlock, Filter}; use crate::chain::bitcoind::{BitcoindChainSource, UtxoSourceClient}; @@ -459,6 +459,38 @@ impl ChainSource { } } } + + pub(crate) fn can_broadcast_transaction(&self, tx: &Transaction) -> Result { + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + match &self.kind { + ChainSourceKind::Bitcoind(bitcoind_chain_source) => { + bitcoind_chain_source.can_broadcast_transaction(tx).await + }, + ChainSourceKind::Esplora{..} => { + // Esplora doesn't support testmempoolaccept equivalent. + unreachable!("Mempool accept testing is not supported with Esplora backend. Use BitcoindRpc for this functionality.") + }, + ChainSourceKind::Electrum{..} => { + // Electrum doesn't support testmempoolaccept equivalent. + unreachable!("Mempool accept testing is not supported with Electrum backend. Use BitcoindRpc for this functionality.") + }, + } + }) + }) + } + + pub(crate) fn get_transaction(&self, txid: &Txid) -> Result, Error> { + tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + match &self.kind { + ChainSourceKind::Bitcoind(bitcoind) => bitcoind.get_transaction(txid).await, + ChainSourceKind::Esplora(esplora) => esplora.get_transaction(txid).await, + ChainSourceKind::Electrum(electrum) => electrum.get_transaction(txid).await, + } + }) + }) + } } impl Filter for ChainSource { diff --git a/src/config.rs b/src/config.rs index 6c9d1640a..afe65e64b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -18,6 +18,7 @@ use lightning::routing::router::RouteParametersConfig; use lightning::util::config::{ ChannelConfig as LdkChannelConfig, MaxDustHTLCExposure as LdkMaxDustHTLCExposure, UserConfig, }; +use bitreq::URL; use crate::logger::LogLevel; @@ -127,7 +128,8 @@ pub(crate) const HRN_RESOLUTION_TIMEOUT_SECS: u64 = 5; /// | `probing_liquidity_limit_multiplier` | 3 | /// | `log_level` | Debug | /// | `anchor_channels_config` | Some(..) | -/// | `route_parameters` | None | +/// | `route_parameters` | None | +/// | `payjoin_config` | None | /// /// See [`AnchorChannelsConfig`] and [`RouteParametersConfig`] for more information regarding their /// respective default values. @@ -192,6 +194,7 @@ pub struct Config { /// **Note:** If unset, default parameters will be used, and you will be able to override the /// parameters on a per-payment basis in the corresponding method calls. pub route_parameters: Option, + pub payjoin_config: Option, } impl Default for Config { @@ -206,6 +209,7 @@ impl Default for Config { anchor_channels_config: Some(AnchorChannelsConfig::default()), route_parameters: None, node_alias: None, + payjoin_config: None, } } } @@ -561,6 +565,12 @@ pub enum AsyncPaymentsRole { Server, } +#[derive(Debug, Clone)] +pub struct PayjoinConfig { + pub payjoin_directory: URL, + pub ohttp_relay: URL, +} + #[cfg(test)] mod tests { use std::str::FromStr; diff --git a/src/error.rs b/src/error.rs index ea0bcca3b..e8c232864 100644 --- a/src/error.rs +++ b/src/error.rs @@ -131,6 +131,8 @@ pub enum Error { AsyncPaymentServicesDisabled, /// Parsing a Human-Readable Name has failed. HrnParsingFailed, + /// A transaction broadcast operation failed. + TxBroadcastFailed, } impl fmt::Display for Error { @@ -213,6 +215,7 @@ impl fmt::Display for Error { Self::HrnParsingFailed => { write!(f, "Failed to parse a human-readable name.") }, + Self::TxBroadcastFailed => write!(f, "Failed to broadcast transaction."), } } } diff --git a/src/io/mod.rs b/src/io/mod.rs index 7afd5bd40..99aad8205 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -78,3 +78,7 @@ pub(crate) const BDK_WALLET_INDEXER_KEY: &str = "indexer"; /// /// [`StaticInvoice`]: lightning::offers::static_invoice::StaticInvoice pub(crate) const STATIC_INVOICE_STORE_PRIMARY_NAMESPACE: &str = "static_invoices"; + +/// The payjoin sessions will be persisted under this key. +pub(crate) const PAYJOIN_SESSION_STORE_PRIMARY_NAMESPACE: &str = "payjoin_sessions"; +pub(crate) const PAYJOIN_SESSION_STORE_SECONDARY_NAMESPACE: &str = ""; diff --git a/src/payment/mod.rs b/src/payment/mod.rs index c82f35c8f..b1baf3933 100644 --- a/src/payment/mod.rs +++ b/src/payment/mod.rs @@ -11,6 +11,7 @@ pub(crate) mod asynchronous; mod bolt11; mod bolt12; mod onchain; +pub(crate) mod payjoin_payment; mod spontaneous; pub(crate) mod store; mod unified; diff --git a/src/payment/payjoin_payment/manager.rs b/src/payment/payjoin_payment/manager.rs new file mode 100644 index 000000000..ad9f7c402 --- /dev/null +++ b/src/payment/payjoin_payment/manager.rs @@ -0,0 +1,602 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +use bitcoin::consensus::encode::serialize_hex; +use bitcoin::{Amount, FeeRate, TxIn, Weight}; +use lightning::ln::channelmanager::PaymentId; +use payjoin::persist::{AsyncSessionPersister, OptionalTransitionOutcome}; +use payjoin::receive::InputPair; +use payjoin::ImplementationError; + +use crate::chain::ChainSource; +use crate::config::Config; +use crate::fee_estimator::{ConfirmationTarget, FeeEstimator, OnchainFeeEstimator}; +use crate::logger::{log_debug, log_error, log_info, LdkLogger, Logger}; +use crate::payment::payjoin_payment::payjoin_session::{PayjoinDirection, PayjoinStatus}; +use crate::types::{Broadcaster, DynStore}; +use crate::Error; +use crate::{ + payment::payjoin_payment::persist::KVStorePayjoinReceiverPersister, types::PayjoinSessionStore, + wallet::Wallet, +}; +use payjoin::bitcoin::psbt::Input; +use payjoin::io::fetch_ohttp_keys; +use payjoin::receive::v2::{ + replay_event_log_async as replay_receiver_event_log_async, HasReplyableError, Initialized, + MaybeInputsOwned, MaybeInputsSeen, Monitor, OutputsUnknown, PayjoinProposal, + ProvisionalProposal, ReceiveSession, Receiver, ReceiverBuilder, UncheckedOriginalPayload, + WantsFeeRange, WantsInputs, WantsOutputs, +}; +use rand::RngCore; +use std::sync::Arc; + +#[derive(Clone)] +pub(crate) struct PayjoinManager { + payjoin_session_store: Arc, + kv_store: Arc, + logger: Arc, + config: Arc, + broadcaster: Arc, + wallet: Arc, + fee_estimator: Arc, + chain_source: Arc, + stop_receiver: tokio::sync::watch::Receiver<()>, +} + +// UPDATE ERROR TYPES!!!!!!!!! +// UPDATE ERROR TYPES!!!!!!!!! +// UPDATE ERROR TYPES!!!!!!!!! +// UPDATE ERROR TYPES!!!!!!!!! +// UPDATE ERROR TYPES!!!!!!!!! +// UPDATE ERROR TYPES!!!!!!!!! + +impl PayjoinManager { + pub(crate) fn new( + payjoin_session_store: Arc, kv_store: Arc, + logger: Arc, config: Arc, broadcaster: Arc, + wallet: Arc, fee_estimator: Arc, + chain_source: Arc, stop_receiver: tokio::sync::watch::Receiver<()>, + ) -> Self { + Self { + payjoin_session_store, + kv_store, + logger, + config, + broadcaster, + wallet, + fee_estimator, + chain_source, + stop_receiver, + } + } + + async fn receive_payjoin( + &self, amount: Amount, fee_rate: Option, + ) -> Result<(), Error> { + let payjoin_config = self.config.payjoin_config.as_ref().ok_or(Error::InvalidAddress)?; + + // Generate a new session ID + let mut random_bytes = [0u8; 32]; + rand::rng().fill_bytes(&mut random_bytes); + let session_id = PaymentId(random_bytes); + + // Create a new persister for this session + let persister = KVStorePayjoinReceiverPersister::new( + session_id, + self.kv_store.clone(), + self.logger.clone(), + )?; + + let address = self.wallet.get_new_address()?; + let ohttp_keys = fetch_ohttp_keys( + payjoin_config.ohttp_relay.clone().as_str(), + payjoin_config.payjoin_directory.clone().as_str(), + ) + .await + .map_err(|e| { + log_error!(self.logger, "Failed to fetch OHTTP keys: {}", e); + Error::InvalidAddress // or create a new error variant like Error::OhttpKeyFetchFailed + })?; + log_debug!(self.logger, "Fetched OHTTP keys: {:?}", ohttp_keys); + + let confirmation_target = ConfirmationTarget::OnchainPayment; + let fee_rate = + fee_rate.unwrap_or_else(|| self.fee_estimator.estimate_fee_rate(confirmation_target)); + + let session = ReceiverBuilder::new( + address, + payjoin_config.payjoin_directory.clone().as_str(), + ohttp_keys, + ) + .map_err(|e| { + log_error!(self.logger, "Failed to create receiver builder: {}", e); + Error::InvalidAddress // or another appropriate variant + })? + .with_amount(amount) + .with_max_fee_rate(fee_rate) + .build() + .save_async(&persister) + .await?; + + log_info!(self.logger, "Receive session established"); + let pj_uri = session.pj_uri(); + log_info!(self.logger, "Request Payjoin by sharing this Payjoin Uri: {}", pj_uri); + + self.process_receiver_session(ReceiveSession::Initialized(session.clone()), &persister) + .await?; + Ok(()) + } + + async fn process_receiver_session( + &self, session: ReceiveSession, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let res = { + match session { + ReceiveSession::Initialized(proposal) => { + self.read_from_directory(proposal, persister).await + }, + ReceiveSession::UncheckedOriginalPayload(proposal) => { + self.check_proposal(proposal, persister).await + }, + ReceiveSession::MaybeInputsOwned(proposal) => { + self.check_inputs_not_owned(proposal, persister).await + }, + ReceiveSession::MaybeInputsSeen(proposal) => { + self.check_no_inputs_seen_before(proposal, persister).await + }, + ReceiveSession::OutputsUnknown(proposal) => { + self.identify_receiver_outputs(proposal, persister).await + }, + ReceiveSession::WantsOutputs(proposal) => { + self.commit_outputs(proposal, persister).await + }, + ReceiveSession::WantsInputs(proposal) => { + self.contribute_inputs(proposal, persister).await + }, + ReceiveSession::WantsFeeRange(proposal) => { + self.apply_fee_range(proposal, persister).await + }, + ReceiveSession::ProvisionalProposal(proposal) => { + self.finalize_proposal(proposal, persister).await + }, + ReceiveSession::PayjoinProposal(proposal) => { + self.send_payjoin_proposal(proposal, persister).await + }, + ReceiveSession::HasReplyableError(error) => { + self.handle_error(error, persister).await + }, + ReceiveSession::Monitor(proposal) => { + self.monitor_payjoin_proposal(proposal, persister).await + }, + ReceiveSession::Closed(_) => return Err(Error::InvalidAddress), + } + }; + res + } + + async fn read_from_directory( + &self, session: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let mut interrupt = self.stop_receiver.clone(); + let receiver = tokio::select! { + res = self.long_poll_fallback(session, &*persister) => res, + _ = interrupt.changed() => { + log_error!(self.logger, "Interrupted. Call the `resume` command to resume all sessions."); + return Err(Error::InvalidAddress); + } + }?; + self.check_proposal(receiver, &*persister).await + } + + async fn long_poll_fallback( + &self, session: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result, Error> { + let payjoin_config = self.config.payjoin_config.as_ref().ok_or(Error::InvalidAddress)?; + let ohttp_relay = payjoin_config.ohttp_relay.clone(); + + let mut session = session; + loop { + let (req, context) = + session.create_poll_request(ohttp_relay.as_str()).map_err(|e| { + log_error!(self.logger, "Failed to create poll request: {}", e); + Error::InvalidAddress + })?; + log_debug!(self.logger, "Polling receive request..."); + let ohttp_response = self.post_request(req).await?; + let state_transition = session + .process_response( + ohttp_response + .bytes() + .await + .map_err(|e| { + log_error!(self.logger, "Failed to read response bytes: {}", e); + Error::InvalidAddress + })? + .to_vec() + .as_slice(), + context, + ) + .save_async(persister) + .await; + match state_transition { + Ok(OptionalTransitionOutcome::Progress(next_state)) => { + log_info!( + self.logger, + "Got a request from the sender. Responding with a Payjoin proposal." + ); + return Ok(next_state); + }, + Ok(OptionalTransitionOutcome::Stasis(current_state)) => { + session = current_state; + continue; + }, + Err(_) => return Err(Error::PersistenceFailed), + } + } + } + + async fn post_request(&self, req: payjoin::Request) -> Result { + let http_client = reqwest::Client::new(); + let client = http_client; + client + .post(req.url) + .header("Content-Type", req.content_type) + .body(req.body) + .send() + .await + .map_err(|e| { + log_error!(self.logger, "HTTP request failed: {}", e); + Error::InvalidAddress + }) + } + + async fn check_proposal( + &self, proposal: Receiver, + persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let proposal = proposal + .check_broadcast_suitability(None, |tx| { + self.chain_source + .can_broadcast_transaction(tx) + .map_err(|e| ImplementationError::from(e.to_string().as_str())) + }) + .save_async(persister) + .await + .map_err(|_| Error::PersistenceFailed)?; + + log_info!(self.logger, "Fallback transaction received. Consider broadcasting this to get paid if the Payjoin fails: {}", serialize_hex(&proposal.extract_tx_to_schedule_broadcast())); + self.check_inputs_not_owned(proposal, persister).await + } + + async fn check_inputs_not_owned( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let proposal = proposal + .check_inputs_not_owned(&mut |input| { + self.wallet + .is_mine(input.to_owned()) + .map_err(|e| ImplementationError::from(e.to_string().as_str())) + }) + .save_async(persister) + .await + .map_err(|_| Error::PersistenceFailed)?; + + self.check_no_inputs_seen_before(proposal, persister).await + } + + async fn check_no_inputs_seen_before( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let proposal = proposal + // TODO: DETERMINE IF SAVING THE INPUT AT THIS POINT IS NECESSARY FOR NOW WE JUST RETURN FALSE + // BUT I THINK IT WOULD BE BETTER TO SAVE IT SO THAT IF THE SESSION IS RESUMED WE CAN CHECK AGAIN + .check_no_inputs_seen_before(&mut |_| Ok(false)) + .save_async(persister) + .await + .map_err(|_| Error::PersistenceFailed)?; + self.identify_receiver_outputs(proposal, persister).await + } + + async fn identify_receiver_outputs( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let proposal = proposal + .identify_receiver_outputs(&mut |output_script| { + self.wallet + .is_mine(output_script.to_owned()) + .map_err(|e| ImplementationError::from(e.to_string().as_str())) + }) + .save_async(persister) + .await + .map_err(|_| Error::PersistenceFailed)?; + self.commit_outputs(proposal, persister).await + } + + async fn commit_outputs( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let proposal = proposal.commit_outputs().save_async(persister).await?; + self.contribute_inputs(proposal, persister).await + } + + async fn contribute_inputs( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let candidate_inputs = self.list_input_pairs()?; + + if candidate_inputs.is_empty() { + return Err({ + log_error!( + self.logger, + "No spendable UTXOs available in wallet. Cannot contribute inputs to payjoin." + ); + Error::InvalidAddress + }); + } + + let selected_input = + proposal.try_preserving_privacy(candidate_inputs).map_err(|_| Error::InvalidAddress)?; + let proposal = proposal + .contribute_inputs(vec![selected_input]) + .map_err(|_| Error::InvalidAddress)? + .commit_inputs() + .save_async(persister) + .await?; + self.apply_fee_range(proposal, persister).await + } + + fn list_input_pairs(&self) -> Result, Error> { + let unspent = self.wallet.list_unspent_utxos()?; + + let mut input_pairs = Vec::with_capacity(unspent.len()); + + for u in unspent { + let txin = TxIn { previous_output: u.outpoint, ..Default::default() }; + let psbtin = Input { witness_utxo: Some(u.output.clone()), ..Default::default() }; + let satisfaction_weight = Weight::from_wu(u.satisfaction_weight); + + let input_pair = + InputPair::new(txin, psbtin, Some(satisfaction_weight)).map_err(|e| { + log_error!(self.logger, "Failed to create InputPair: {}", e); + Error::InvalidAddress + })?; + + input_pairs.push(input_pair); + } + + Ok(input_pairs) + } + + async fn apply_fee_range( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let confirmation_target = ConfirmationTarget::OnchainPayment; + + let proposal = proposal + .apply_fee_range(None, None) + .save_async(persister) + .await + .map_err(|_| Error::PersistenceFailed)?; + + self.finalize_proposal(proposal, persister).await + } + + async fn finalize_proposal( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let proposal = proposal + .finalize_proposal(|psbt| { + self.wallet + .process_psbt(psbt.clone()) + .map_err(|e| ImplementationError::from(e.to_string().as_str())) + }) + .save_async(persister) + .await + .map_err(|_| Error::PersistenceFailed)?; + self.send_payjoin_proposal(proposal, persister).await + } + + async fn send_payjoin_proposal( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let payjoin_config = self.config.payjoin_config.as_ref().ok_or(Error::InvalidAddress)?; + let ohttp_relay = payjoin_config.ohttp_relay.clone(); + let (req, ohttp_ctx) = proposal.create_post_request(ohttp_relay.as_str()).map_err(|e| { + log_error!(self.logger, "v2 req extraction failed {}", e); + Error::InvalidAddress + })?; + let res = self.post_request(req).await?; + let payjoin_psbt = proposal.psbt().clone(); + let session = proposal + .process_response( + &res.bytes().await.map_err(|e| { + log_error!(self.logger, "Failed to read response bytes: {}", e); + Error::InvalidAddress + })?, + ohttp_ctx, + ) + .save_async(persister) + .await + .map_err(|_| Error::PersistenceFailed)?; + + log_info!( + self.logger, + "Response successful. Watch mempool for successful Payjoin. TXID: {}", + payjoin_psbt.extract_tx_unchecked_fee_rate().compute_txid() + ); + + return self.monitor_payjoin_proposal(session, persister).await; + } + + async fn handle_error( + &self, session: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + let payjoin_config = self.config.payjoin_config.as_ref().ok_or(Error::InvalidAddress)?; + let ohttp_relay = payjoin_config.ohttp_relay.clone(); + + let (err_req, err_ctx) = session + .create_error_request(ohttp_relay.as_str()) + .map_err(|e| Error::InvalidAddress)?; + + let err_response = match self.post_request(err_req).await { + Ok(response) => response, + Err(e) => return Err(Error::InvalidAddress), + }; + + let err_bytes = match err_response.bytes().await { + Ok(bytes) => bytes.to_vec(), + Err(_) => return Err(Error::InvalidAddress), + }; + + if let Err(e) = + session.process_error_response(&err_bytes, err_ctx).save_async(persister).await + { + return Err(Error::InvalidAddress); + } + + Ok(()) + } + + async fn monitor_payjoin_proposal( + &self, proposal: Receiver, persister: &KVStorePayjoinReceiverPersister, + ) -> Result<(), Error> { + // On a session resumption, the receiver will resume again in this state. + let poll_interval = tokio::time::Duration::from_millis(200); + let timeout_duration = tokio::time::Duration::from_secs(5); + + let mut interval = tokio::time::interval(poll_interval); + interval.tick().await; + + log_debug!(self.logger, "Polling for payment confirmation"); + + let result = tokio::time::timeout(timeout_duration, async { + loop { + interval.tick().await; + let check_result = proposal + .check_payment(|txid| { + self.chain_source + .get_transaction(&txid) + .map_err(|e| ImplementationError::from(e.to_string().as_str())) + }) + .save_async(persister) + .await; + + match check_result { + Ok(_) => { + log_info!(self.logger, "Payjoin transaction detected in the mempool!"); + return Ok(()); + }, + Err(_) => { + // keep polling + continue; + }, + } + } + }) + .await; + + match result { + Ok(ok) => ok, + Err(_) => Err({ + log_error!( + self.logger, + "Timeout waiting for payment confirmation after {:?}", + timeout_duration + ); + Error::InvalidAddress + }), + } + } + + async fn resume_payjoins(&self) -> Result<(), Error> { + let recv_session_ids = self + .payjoin_session_store + .list_filter(|p| { + p.direction == PayjoinDirection::Receive && p.status == PayjoinStatus::Active + }) + .iter() + .map(|s| s.session_id.clone()) + .collect::>(); + + if recv_session_ids.is_empty() { + log_info!(self.logger, "No sessions to resume."); + return Ok(()); + } + + let mut tasks = Vec::new(); + + // Process receiver sessions + for session_id in recv_session_ids { + let self_clone = self.clone(); + // Create a persister for this session + let recv_persister = match KVStorePayjoinReceiverPersister::from_session( + session_id.clone(), + self.kv_store.clone(), + self.logger.clone(), + ) { + Ok(p) => p, + Err(e) => { + log_error!( + self.logger, + "Failed to create persister for session {:?}: {:?}", + session_id, + e + ); + continue; + }, + }; + + match replay_receiver_event_log_async(&recv_persister).await { + Ok((receiver_state, _)) => { + tasks.push(tokio::spawn(async move { + self_clone.process_receiver_session(receiver_state, &recv_persister).await + })); + }, + Err(e) => { + log_error!( + self.logger, + "An error {:?} occurred while replaying receiver session", + e + ); + self.close_failed_session(&recv_persister, &session_id, "receiver").await; + }, + } + } + + let mut interrupt = self.stop_receiver.clone(); + tokio::select! { + _ = async { + for task in tasks { + let _ = task.await; + } + } => { + println!("All payjoin resumed sessions completed."); + } + _ = interrupt.changed() => { + println!("Resumed payjoin sessions were interrupted."); + } + } + Ok(()) + } + + async fn close_failed_session

(&self, persister: &P, session_id: &PaymentId, role: &str) + where + P: AsyncSessionPersister, + { + if let Err(close_err) = AsyncSessionPersister::close(persister).await { + log_error!( + self.logger, + "Failed to close {} session {}: {:?}", + role, + session_id, + close_err + ); + } else { + log_error!(self.logger, "Closed failed {} session: {}", role, session_id); + } + } +} diff --git a/src/payment/payjoin_payment/mod.rs b/src/payment/payjoin_payment/mod.rs new file mode 100644 index 000000000..e3dc50d8e --- /dev/null +++ b/src/payment/payjoin_payment/mod.rs @@ -0,0 +1,10 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +pub(crate) mod manager; +pub(crate) mod payjoin_session; +pub(crate) mod persist; diff --git a/src/payment/payjoin_payment/payjoin_session.rs b/src/payment/payjoin_payment/payjoin_session.rs new file mode 100644 index 000000000..7c35425ce --- /dev/null +++ b/src/payment/payjoin_payment/payjoin_session.rs @@ -0,0 +1,249 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + +use lightning::ln::channelmanager::PaymentId; +use lightning::ln::msgs::DecodeError; +use lightning::util::ser::{Readable, Writeable}; +use lightning::{ + _init_and_read_len_prefixed_tlv_fields, impl_writeable_tlv_based, + impl_writeable_tlv_based_enum, write_tlv_fields, +}; + +use crate::data_store::{StorableObject, StorableObjectUpdate}; + +/// Represents a payjoin session with persisted events +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PayjoinSession { + /// Session identifier (uses PaymentId from PaymentDetails) + pub session_id: PaymentId, + + /// Direction of the payjoin (Send or Receive) + pub direction: PayjoinDirection, + + /// HPKE public key of receiver (only for sender sessions) + pub receiver_pubkey: Option>, + + /// Serialized session events as JSON strings + pub events: Vec, + + /// Current status of the session + pub status: PayjoinStatus, + + /// Unix timestamp of session completion (if completed) + pub completed_at: Option, + + /// The timestamp, in seconds since start of the UNIX epoch, when this entry was last updated. + pub latest_update_timestamp: u64, +} + +impl PayjoinSession { + pub fn new( + session_id: PaymentId, direction: PayjoinDirection, receiver_pubkey: Option>, + ) -> Self { + let latest_update_timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::from_secs(0)) + .as_secs(); + Self { + session_id, + direction, + receiver_pubkey, + events: Vec::new(), + status: PayjoinStatus::Active, + completed_at: None, + latest_update_timestamp, + } + } +} + +impl Writeable for PayjoinSession { + fn write( + &self, writer: &mut W, + ) -> Result<(), lightning::io::Error> { + write_tlv_fields!(writer, { + (0, self.session_id, required), + (2, self.direction, required), + (4, self.receiver_pubkey, option), + (6, self.events, required_vec), + (8, self.status, required), + (10, self.completed_at, required), + (12, self.latest_update_timestamp, required), + }); + Ok(()) + } +} + +impl Readable for PayjoinSession { + fn read(reader: &mut R) -> Result { + let unix_time_secs = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::from_secs(0)) + .as_secs(); + _init_and_read_len_prefixed_tlv_fields!(reader, { + (0, session_id, required), + (2, direction, required), + (4, receiver_pubkey, option), + (6, events, required_vec), + (8, status, required), + (10, completed_at, option), + (12, latest_update_timestamp, (default_value, unix_time_secs)) + }); + + let session_id: PaymentId = session_id.0.ok_or(DecodeError::InvalidValue)?; + let direction: PayjoinDirection = direction.0.ok_or(DecodeError::InvalidValue)?; + let status: PayjoinStatus = status.0.ok_or(DecodeError::InvalidValue)?; + let latest_update_timestamp: u64 = + latest_update_timestamp.0.ok_or(DecodeError::InvalidValue)?; + + Ok(PayjoinSession { + session_id, + direction, + receiver_pubkey, + events, + status, + completed_at, + latest_update_timestamp, + }) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PayjoinDirection { + /// The session is for sending a payment + Send, + /// The session is for receiving a payment + Receive, +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PayjoinStatus { + /// The session is active + Active, + /// The session has completed successfully + Completed, + /// The session has failed + Failed, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct SerializedSessionEvent { + /// JSON representation of the event + pub event_json: String, + /// Unix timestamp of when the event occurred + pub created_at: u64, +} + +impl_writeable_tlv_based!(SerializedSessionEvent, { + (0, event_json, required), + (2, created_at, required), +}); + +impl_writeable_tlv_based_enum!(PayjoinDirection, + (0, Send) => {}, + (2, Receive) => {} +); + +impl_writeable_tlv_based_enum!(PayjoinStatus, + (0, Active) => {}, + (2, Completed) => {}, + (4, Failed) => {} +); + +/// Represents a payjoin session with persisted events +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) struct PayjoinSessionUpdate { + pub session_id: PaymentId, + pub receiver_pubkey: Option>>, + pub events: Option>, + pub status: Option, + pub completed_at: Option>, +} + +impl PayjoinSessionUpdate { + pub fn new(id: PaymentId) -> Self { + Self { + session_id: id, + receiver_pubkey: None, + events: None, + status: None, + completed_at: None, + } + } +} + +impl From<&PayjoinSession> for PayjoinSessionUpdate { + fn from(value: &PayjoinSession) -> Self { + Self { + session_id: value.session_id, + receiver_pubkey: Some(value.receiver_pubkey.clone()), + events: Some(value.events.clone()), + status: Some(value.status), + completed_at: Some(value.completed_at), + } + } +} + +impl StorableObject for PayjoinSession { + type Id = PaymentId; + type Update = PayjoinSessionUpdate; + + fn id(&self) -> Self::Id { + self.session_id + } + + fn update(&mut self, update: &Self::Update) -> bool { + debug_assert_eq!( + self.session_id, update.session_id, + "We should only ever override data for the same id" + ); + + let mut updated = false; + + macro_rules! update_if_necessary { + ($val:expr, $update:expr) => { + if $val != $update { + $val = $update; + updated = true; + } + }; + } + + if let Some(receiver_pubkey_opt) = &update.receiver_pubkey { + update_if_necessary!(self.receiver_pubkey, receiver_pubkey_opt.clone()); + } + if let Some(events_opt) = &update.events { + update_if_necessary!(self.events, events_opt.clone()); + } + if let Some(status_opt) = update.status { + update_if_necessary!(self.status, status_opt); + } + if let Some(completed_at_opt) = update.completed_at { + update_if_necessary!(self.completed_at, completed_at_opt); + } + + if updated { + self.latest_update_timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::from_secs(0)) + .as_secs(); + } + + updated + } + + fn to_update(&self) -> Self::Update { + self.into() + } +} + +impl StorableObjectUpdate for PayjoinSessionUpdate { + fn id(&self) -> ::Id { + self.session_id + } +} diff --git a/src/payment/payjoin_payment/persist.rs b/src/payment/payjoin_payment/persist.rs new file mode 100644 index 000000000..30fa57a5f --- /dev/null +++ b/src/payment/payjoin_payment/persist.rs @@ -0,0 +1,227 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +use crate::payment::payjoin_payment::payjoin_session::{PayjoinStatus, SerializedSessionEvent}; +use crate::Error; +use std::{ + sync::Arc, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; + +use lightning::util::ser::Readable; +use lightning::{io::Cursor, ln::channelmanager::PaymentId, util::persist::KVStoreSync}; + +use crate::io::{ + PAYJOIN_SESSION_STORE_PRIMARY_NAMESPACE, PAYJOIN_SESSION_STORE_SECONDARY_NAMESPACE, +}; +use crate::logger::{log_error, LdkLogger, Logger}; +use crate::payment::payjoin_payment::payjoin_session::{PayjoinDirection, PayjoinSession}; +use crate::types::{DynStore, PayjoinSessionStore}; + +use payjoin::persist::AsyncSessionPersister; +use payjoin::receive::v2::SessionEvent as ReceiverSessionEvent; + +pub(crate) struct KVStorePayjoinSenderPersister { + session_id: PaymentId, + kv_store: Arc, +} + +pub(crate) struct KVStorePayjoinReceiverPersister { + session_id: PaymentId, + kv_store: Arc, +} + +impl KVStorePayjoinReceiverPersister { + pub fn new( + session_id: PaymentId, kv_store: Arc, logger: Arc, + ) -> Result { + let sessions = Self::load_all_sessions(&kv_store, &logger)?; + let data_store = Arc::new(PayjoinSessionStore::new( + sessions, + PAYJOIN_SESSION_STORE_PRIMARY_NAMESPACE.to_string(), + PAYJOIN_SESSION_STORE_SECONDARY_NAMESPACE.to_string(), + kv_store, + logger, + )); + + let session = PayjoinSession::new(session_id, PayjoinDirection::Receive, None); + + data_store.insert(session)?; + + Ok(Self { session_id, kv_store: data_store }) + } + + /// Reconstruct persister from existing session + pub fn from_session( + session_id: PaymentId, kv_store: Arc, logger: Arc, + ) -> Result { + let sessions = Self::load_all_sessions(&kv_store, &logger)?; + let data_store = Arc::new(PayjoinSessionStore::new( + sessions, + PAYJOIN_SESSION_STORE_PRIMARY_NAMESPACE.to_string(), + PAYJOIN_SESSION_STORE_SECONDARY_NAMESPACE.to_string(), + kv_store, + logger, + )); + + if data_store.get(&session_id).is_none() { + return Err(Error::InvalidPaymentId); + } + + Ok(Self { session_id, kv_store: data_store }) + } + + /// Load all sessions from KV store + fn load_all_sessions( + kv_store: &Arc, logger: &Arc, + ) -> Result, Error> { + let keys = KVStoreSync::list( + &**kv_store, + PAYJOIN_SESSION_STORE_PRIMARY_NAMESPACE, + PAYJOIN_SESSION_STORE_SECONDARY_NAMESPACE, + ) + .map_err(|e| { + log_error!(logger, "Failed to list payjoin sessions: {:?}", e); + Error::PersistenceFailed + })?; + + let mut sessions = Vec::new(); + for key in keys { + match KVStoreSync::read( + &**kv_store, + PAYJOIN_SESSION_STORE_PRIMARY_NAMESPACE, + PAYJOIN_SESSION_STORE_SECONDARY_NAMESPACE, + &key, + ) { + Ok(data) => { + let mut reader = Cursor::new(&data[..]); + match PayjoinSession::read(&mut reader) { + Ok(session) => sessions.push(session), + Err(e) => { + log_error!( + logger, + "Failed to deserialize PayjoinSession for key {}: {:?}. Skipping corrupted session.", + key, e + ); + continue; + }, + } + }, + Err(e) => { + log_error!( + logger, + "Failed to read PayjoinSession data for key {}: {:?}", + key, + e + ); + continue; + }, + } + } + + Ok(sessions) + } + + /// Get all active Receiver session IDs + pub fn get_active_session_ids( + kv_store: Arc, logger: Arc, + ) -> Result, Error> { + let sessions = Self::load_all_sessions(&kv_store, &logger)?; + Ok(sessions + .into_iter() + .filter(|s| { + s.direction == PayjoinDirection::Receive && s.status == PayjoinStatus::Active + }) + .map(|s| s.session_id) + .collect()) + } + + /// Get all inactive Receiver sessions (for cleanup) + pub fn get_inactive_sessions( + kv_store: Arc, logger: Arc, + ) -> Result, Error> { + let sessions = Self::load_all_sessions(&kv_store, &logger)?; + Ok(sessions + .into_iter() + .filter(|s| { + s.direction == PayjoinDirection::Receive + && s.status != PayjoinStatus::Active + && s.completed_at.is_some() + }) + .map(|s| (s.session_id, s.completed_at.unwrap())) + .collect()) + } +} + +impl AsyncSessionPersister for KVStorePayjoinReceiverPersister { + type SessionEvent = ReceiverSessionEvent; + type InternalStorageError = Error; + + fn save_event( + &self, event: Self::SessionEvent, + ) -> impl std::future::Future> + Send { + async move { + let mut session = self.kv_store.get(&self.session_id).ok_or(Error::InvalidPaymentId)?; + + let event_json = serde_json::to_string(&event).map_err(|_| Error::PersistenceFailed)?; + + session.events.push(SerializedSessionEvent { + event_json, + created_at: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or(Duration::from_secs(0)) + .as_secs(), + }); + + self.kv_store.insert_or_update(session)?; + + Ok(()) + } + } + + fn load( + &self, + ) -> impl std::future::Future< + Output = Result< + Box + Send>, + Self::InternalStorageError, + >, + > + Send { + async move { + let session = self.kv_store.get(&self.session_id).ok_or(Error::InvalidPaymentId)?; + + let events: Vec = session + .events + .iter() + .map(|e| serde_json::from_str(&e.event_json)) + .collect::, _>>() + .map_err(|_| Error::PersistenceFailed)?; + + Ok(Box::new(events.into_iter()) as Box + Send>) + } + } + + fn close( + &self, + ) -> impl std::future::Future> + Send { + async move { + let mut session = self.kv_store.get(&self.session_id).ok_or(Error::InvalidPaymentId)?; + + session.completed_at = Some(now()); + session.status = PayjoinStatus::Completed; + + self.kv_store.insert_or_update(session)?; + + Ok(()) + } + } +} + +// Helper function for timestamp +fn now() -> u64 { + SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or(Duration::from_secs(0)).as_secs() +} diff --git a/src/types.rs b/src/types.rs index 614efd90e..e69153f84 100644 --- a/src/types.rs +++ b/src/types.rs @@ -39,6 +39,7 @@ use crate::data_store::DataStore; use crate::fee_estimator::OnchainFeeEstimator; use crate::logger::Logger; use crate::message_handler::NodeCustomMessageHandler; +use crate::payment::payjoin_payment::payjoin_session::PayjoinSession; use crate::payment::PaymentDetails; use crate::runtime::RuntimeSpawner; @@ -321,6 +322,8 @@ pub(crate) type BumpTransactionEventHandler = pub(crate) type PaymentStore = DataStore>; +pub(crate) type PayjoinSessionStore = DataStore>; + /// A local, potentially user-provided, identifier of a channel. /// /// By default, this will be randomly generated for the user to ensure local uniqueness. diff --git a/src/wallet/mod.rs b/src/wallet/mod.rs index 5fd7b3d8e..bee19d99c 100644 --- a/src/wallet/mod.rs +++ b/src/wallet/mod.rs @@ -806,6 +806,98 @@ impl Wallet { Ok(tx) } + + /// Check if a script belongs to this wallet + pub fn is_mine(&self, script: ScriptBuf) -> Result { + let locked_wallet = self.inner.lock().unwrap(); + Ok(locked_wallet.is_mine(script)) + } + + pub fn process_psbt(&self, mut psbt: Psbt) -> Result { + let locked_wallet = self.inner.lock().unwrap(); + + let mut sign_options = SignOptions::default(); + sign_options.trust_witness_utxo = true; + + locked_wallet.sign(&mut psbt, sign_options).map_err(|e| { + log_error!(self.logger, "Failed to sign PSBT: {}", e); + Error::WalletOperationFailed + })?; + + // Return the signed PSBT (not extracted transaction) + Ok(psbt) + } + + pub fn list_unspent_utxos(&self) -> Result, Error> { + let locked_wallet = self.inner.lock().unwrap(); + + let mut utxos = Vec::new(); + + for u in locked_wallet.list_unspent() { + let script_pubkey = &u.txout.script_pubkey; + + match script_pubkey.witness_version() { + Some(version @ WitnessVersion::V0) => { + // P2WPKH handling + let witness_bytes = &script_pubkey.as_bytes()[2..]; + let witness_program = + WitnessProgram::new(version, witness_bytes).map_err(|e| { + log_error!(self.logger, "Failed to retrieve script payload: {}", e); + Error::InvalidAddress + })?; + + let wpkh = WPubkeyHash::from_slice(&witness_program.program().as_bytes()) + .map_err(|e| { + log_error!(self.logger, "Failed to retrieve script payload: {}", e); + Error::InvalidAddress + })?; + + let utxo = Utxo::new_v0_p2wpkh(u.outpoint, u.txout.value, &wpkh); + utxos.push(utxo); + }, + Some(version @ WitnessVersion::V1) => { + // P2TR (Taproot) handling + let witness_bytes = &script_pubkey.as_bytes()[2..]; + let witness_program = + WitnessProgram::new(version, witness_bytes).map_err(|e| { + log_error!(self.logger, "Failed to retrieve script payload: {}", e); + Error::InvalidAddress + })?; + + XOnlyPublicKey::from_slice(&witness_program.program().as_bytes()).map_err( + |e| { + log_error!(self.logger, "Failed to retrieve script payload: {}", e); + Error::InvalidAddress + }, + )?; + + let utxo = Utxo { + outpoint: u.outpoint, + output: TxOut { + value: u.txout.value, + script_pubkey: ScriptBuf::new_witness_program(&witness_program), + }, + satisfaction_weight: 1 /* empty script_sig */ * WITNESS_SCALE_FACTOR as u64 + + 1 /* witness items */ + 1 /* schnorr sig len */ + 64, // schnorr sig + }; + utxos.push(utxo); + }, + Some(version) => { + log_error!(self.logger, "Unexpected witness version: {}", version); + continue; + }, + None => { + log_error!( + self.logger, + "Tried to use a non-witness script. This must never happen." + ); + panic!("Tried to use a non-witness script. This must never happen."); + }, + } + } + + Ok(utxos) + } } impl Listen for Wallet {