diff --git a/contracts/src/Chainvoice.sol b/contracts/src/Chainvoice.sol index 042397b4..bdf637bd 100644 --- a/contracts/src/Chainvoice.sol +++ b/contracts/src/Chainvoice.sol @@ -100,6 +100,27 @@ contract Chainvoice { _entered = false; } + modifier validInvoice(uint256 invoiceId) { + require(invoiceId < invoices.length, "Invalid invoice ID"); + _; + } + + modifier onlyCreator(uint256 invoiceId) { + require(msg.sender == invoices[invoiceId].from, "Only invoice creator can cancel"); + _; + } + + modifier onlyPayer(uint256 invoiceId) { + require(msg.sender == invoices[invoiceId].to, "Not authorized"); + _; + } + + modifier invoiceActive(uint256 invoiceId) { + require(!invoices[invoiceId].isPaid, "Already paid"); + require(!invoices[invoiceId].isCancelled, "Invoice is cancelled"); + _; + } + // Constants uint256 public constant MAX_BATCH = 50; @@ -217,12 +238,8 @@ contract Chainvoice { } // ========== Cancel single invoice ========== - function cancelInvoice(uint256 invoiceId) external { - if (invoiceId >= invoices.length) revert InvalidInvoiceId(); + function cancelInvoice(uint256 invoiceId) external validInvoice(invoiceId) onlyCreator(invoiceId) invoiceActive(invoiceId) { InvoiceDetails storage invoice = invoices[invoiceId]; - - if (msg.sender != invoice.from) revert NotInvoiceCreator(); - if (invoice.isPaid || invoice.isCancelled) revert InvoiceNotCancellable(); invoice.isCancelled = true; @@ -235,14 +252,9 @@ contract Chainvoice { } // ========== Pay single invoice ========== - function payInvoice(uint256 invoiceId) external payable nonReentrant { - if (invoiceId >= invoices.length) revert InvalidInvoiceId(); + function payInvoice(uint256 invoiceId) external payable nonReentrant validInvoice(invoiceId) onlyPayer(invoiceId) invoiceActive(invoiceId) { InvoiceDetails storage invoice = invoices[invoiceId]; - if (msg.sender != invoice.to) revert NotAuthorizedPayer(); - if (invoice.isPaid) revert InvoiceAlreadyPaid(); - if (invoice.isCancelled) revert InvoiceCancelledError(); - // Effects first for CEI (mark paid, bump fees), then interactions invoice.isPaid = true; if (invoice.tokenAddress == address(0)) { @@ -252,10 +264,11 @@ contract Chainvoice { (bool sent, ) = payable(invoice.from).call{value: invoice.amountDue}(""); if (!sent) revert NativeTransferFailed(); } else { - if (msg.value != fee) revert FeeMustBeNative(); - if (IERC20(invoice.tokenAddress).allowance(msg.sender, address(this)) < invoice.amountDue) { - revert InsufficientAllowance(); - } + require(msg.value == fee, "Must pay fee in native token"); + require( + IERC20(invoice.tokenAddress).allowance(msg.sender, address(this)) >= invoice.amountDue, + "Insufficient allowance" + ); accumulatedFees += fee; bool transferSuccess = IERC20(invoice.tokenAddress).transferFrom( @@ -274,7 +287,6 @@ contract Chainvoice { invoice.tokenAddress ); } - // ========== Batch pay (all-or-nothing) ========== function payInvoicesBatch(uint256[] calldata invoiceIds) external payable nonReentrant { uint256 n = invoiceIds.length; @@ -351,6 +363,7 @@ contract Chainvoice { ) external view + validInvoice(invoiceId) returns (bool canPay, uint256 payerBalance, uint256 allowanceAmount) { if (invoiceId >= invoices.length) revert InvalidInvoiceId(); @@ -394,8 +407,7 @@ contract Chainvoice { return result; } - function getInvoice(uint256 invoiceId) external view returns (InvoiceDetails memory) { - if (invoiceId >= invoices.length) revert InvalidInvoiceId(); + function getInvoice(uint256 invoiceId) external view validInvoice(invoiceId) returns (InvoiceDetails memory) { return invoices[invoiceId]; }