Skip to content

Add swap fee #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions src/ForwarderLogic.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ contract ForwarderLogic is IForwarderLogic {
using EnumerableSet for EnumerableSet.AddressSet;
using SafeERC20 for IERC20;

uint256 internal constant BPS = 10000;

address private immutable _router;
address private _feeReceiver;

EnumerableSet.AddressSet private _trustedRouter;
mapping(address => bool) private _blacklist;

constructor(address router) {
constructor(address router, address feeReceiver) {
if (router == address(0)) revert ForwarderLogic__InvalidRouter();
_router = router;
_setFeeReceiver(feeReceiver);
}

/**
Expand All @@ -49,13 +53,23 @@ contract ForwarderLogic is IForwarderLogic {
return _blacklist[account];
}

/**
* @dev Returns the fee receiver address.
*/
function getFeeReceiver() external view override returns (address) {
return _feeReceiver;
}

/**
* @dev Swaps an exact amount of tokenIn for as much tokenOut as possible using an external router.
* The function will simply forward the call to the router and return the amount of tokenIn and tokenOut swapped.
*
* Requirements:
* - The caller must be the router.
* - The data must be formatted using abi.encodePacked(approval, router, routerData).
* - The third party router must be trusted.
* - The data must be formatted using abi.encodePacked(uint128(feeAmount), approval, router, routerData).
* - The fee amount must be less than or equal to the amountIn.
* - The router data must use at most `amountIn - feeAmount` of tokenIn.
*/
function swapExactIn(
address tokenIn,
Expand All @@ -71,10 +85,21 @@ contract ForwarderLogic is IForwarderLogic {

address approval = address(uint160(bytes20(data[0:20])));
address router = address(uint160(bytes20(data[20:40])));
bytes memory routerData = data[40:];
uint256 feePercent = uint256(uint16(bytes2(data[40:42])));
bytes memory routerData = data[42:];

RouterLib.transfer(_router, tokenIn, from, address(this), amountIn);

uint256 feeAmount = (amountIn * feePercent) / BPS;
if (feeAmount > 0) {
amountIn -= feeAmount;

address feeReceiver = _feeReceiver;

TokenLib.transfer(tokenIn, feeReceiver, feeAmount);
emit FeeSent(tokenIn, from, feeReceiver, feeAmount);
}

SafeERC20.forceApprove(IERC20(tokenIn), approval, amountIn);

_call(router, routerData);
Expand Down Expand Up @@ -140,6 +165,32 @@ contract ForwarderLogic is IForwarderLogic {
emit BlacklistUpdated(account, blacklisted);
}

/**
* @dev Updates the fee receiver.
*
* Requirements:
* - The caller must be the router owner.
*/
function setFeeReceiver(address feeReceiver) external override {
if (msg.sender != Ownable(_router).owner()) revert ForwarderLogic__OnlyRouterOwner();

_setFeeReceiver(feeReceiver);
}

/**
* @dev Sets the fee receiver.
*
* Requirements:
* - The fee receiver must not be the zero address.
*/
function _setFeeReceiver(address feeReceiver) private {
if (feeReceiver == address(0)) revert ForwarderLogic__InvalidFeeReceiver();

_feeReceiver = feeReceiver;

emit FeeReceiverSet(feeReceiver);
}

/**
* @dev Calls the target contract with the provided data.
*
Expand Down
4 changes: 1 addition & 3 deletions src/Router.sol
Original file line number Diff line number Diff line change
Expand Up @@ -294,10 +294,8 @@ contract Router is Ownable2Step, ReentrancyGuard, IRouter {

uint256 balance = TokenLib.universalBalanceOf(tokenOut, recipient);

address logic_ = logic; // avoid stack too deep error

(totalIn, totalOut) =
RouterLib.swap(_allowances, tokenIn, tokenOut, amountIn, amountOut, from, recipient, route, exactIn, logic_);
RouterLib.swap(_allowances, tokenIn, tokenOut, amountIn, amountOut, from, recipient, route, exactIn, logic);

if (recipient == address(this)) {
totalOut = _verifySwap(tokenOut, recipient, balance, amountOut, totalOut);
Expand Down
18 changes: 9 additions & 9 deletions src/RouterAdapter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,23 @@ abstract contract RouterAdapter {
* @dev Returns the amount of tokenIn needed to get amountOut from the pair.
*
* Requirements:
* - The id of the flags must be valid.
* - The id of the flags must be valid and not the FEE_ID.
*/
function _getAmountIn(address pair, uint256 flags, uint256 amountOut) internal returns (uint256) {
function _getAmountIn(address pair, uint256 flags, uint256 amountOut) internal returns (uint256 amountIn) {
uint256 id = Flags.id(flags);

if (id == Flags.UNISWAP_V2_ID) {
return _getAmountInUV2(pair, flags, amountOut);
amountIn = _getAmountInUV2(pair, flags, amountOut);
} else if (id == Flags.LFJ_LEGACY_LIQUIDITY_BOOK_ID) {
return _getAmountInLegacyLB(pair, flags, amountOut);
amountIn = _getAmountInLegacyLB(pair, flags, amountOut);
} else if (id == Flags.LFJ_LIQUIDITY_BOOK_ID) {
return _getAmountInLB(pair, flags, amountOut);
amountIn = _getAmountInLB(pair, flags, amountOut);
} else if (id == Flags.UNISWAP_V3_ID) {
return _getAmountInUV3(pair, flags, amountOut);
amountIn = _getAmountInUV3(pair, flags, amountOut);
} else if (id == Flags.LFJ_TOKEN_MILL_ID) {
return _getAmountInTM(pair, flags, amountOut);
amountIn = _getAmountInTM(pair, flags, amountOut);
} else if (id == Flags.LFJ_TOKEN_MILL_V2_ID) {
return _getAmountInTMV2(pair, flags, amountOut);
amountIn = _getAmountInTMV2(pair, flags, amountOut);
} else {
revert RouterAdapter__InvalidId();
}
Expand All @@ -82,7 +82,7 @@ abstract contract RouterAdapter {
* @dev Swaps tokens from the sender to the recipient.
*
* Requirements:
* - The id of the flags must be valid.
* - The id of the flags must be valid and not the FEE_ID.
*/
function _swap(address pair, address tokenIn, uint256 amountIn, address recipient, uint256 flags)
internal
Expand Down
134 changes: 113 additions & 21 deletions src/RouterLogic.sol
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,27 @@ contract RouterLogic is RouterAdapter, IRouterLogic {

uint256 internal constant BPS = 10000;

address private _feeReceiver;

/**
* @dev Constructor for the RouterLogic contract.
*
* Requirements:
* - The router address must be a contract with code.
* - The fee receiver address must not be the zero address.
*/
constructor(address router, address routerV2_0) RouterAdapter(routerV2_0) {
constructor(address router, address routerV2_0, address feeReceiver) RouterAdapter(routerV2_0) {
if (router.code.length == 0) revert RouterLogic__InvalidRouter();

_router = router;
_setFeeReceiver(feeReceiver);
}

/**
* @dev Returns the fee receiver address.
*/
function getFeeReceiver() external view override returns (address) {
return _feeReceiver;
}

/**
Expand All @@ -55,23 +66,24 @@ contract RouterLogic is RouterAdapter, IRouterLogic {
address to,
bytes calldata route
) external override returns (uint256, uint256) {
(uint256 ptr, uint256 nbTokens, uint256 nbSwaps) = _startAndVerify(route, tokenIn, tokenOut);
(uint256 feePtr, uint256 ptr, uint256 nbTokens, uint256 nbSwaps) = _startAndVerify(route, tokenIn, tokenOut);

uint256 feeAmount = _getFee(route, feePtr, amountIn, true);
uint256 amountInWithoutFee = amountIn - feeAmount;

_sendFee(route, 0, from, feeAmount);

uint256[] memory balances = new uint256[](nbTokens);

balances[0] = amountIn;
uint256 total = amountIn;
balances[0] = amountInWithoutFee;
uint256 total = amountInWithoutFee;

bytes32 value;
{
address from_ = from;
address to_ = to;
for (uint256 i; i < nbSwaps; i++) {
(ptr, value) = PackedRoute.next(route, ptr);

unchecked {
total += _swapExactInSingle(route, balances, from_, to_, value);
}
for (uint256 i; i < nbSwaps; i++) {
(ptr, value) = PackedRoute.next(route, ptr);

unchecked {
total += _swapExactInSingle(route, balances, from, to, value);
}
}

Expand Down Expand Up @@ -108,24 +120,27 @@ contract RouterLogic is RouterAdapter, IRouterLogic {
address to,
bytes calldata route
) external override returns (uint256 totalIn, uint256 totalOut) {
(uint256 ptr, uint256 nbTokens, uint256 nbSwaps) = _startAndVerify(route, tokenIn, tokenOut);
(uint256 feePtr, uint256 ptr, uint256 nbTokens, uint256 nbSwaps) = _startAndVerify(route, tokenIn, tokenOut);

if (PackedRoute.isTransferTax(route)) revert RouterLogic__TransferTaxNotSupported();

(uint256 amountIn, uint256[] memory amountsIn) = _getAmountsIn(route, amountOut, nbTokens, nbSwaps);

if (amountIn > amountInMax) revert RouterLogic__ExceedsMaxAmountIn(amountIn, amountInMax);
uint256 feeAmount = _getFee(route, feePtr, amountIn, false);
uint256 amountInWithFee = amountIn + feeAmount;

if (amountInWithFee > amountInMax) revert RouterLogic__ExceedsMaxAmountIn(amountInWithFee, amountInMax);

_sendFee(route, 0, from, feeAmount);

bytes32 value;
address from_ = from;
address to_ = to;
for (uint256 i; i < nbSwaps; i++) {
(ptr, value) = PackedRoute.next(route, ptr);

_swapExactOutSingle(route, nbTokens, from_, to_, value, amountsIn[i]);
_swapExactOutSingle(route, nbTokens, from, to, value, amountsIn[i]);
}

return (amountIn, amountOut);
return (amountInWithFee, amountOut);
}

/**
Expand All @@ -140,6 +155,17 @@ contract RouterLogic is RouterAdapter, IRouterLogic {
token == address(0) ? TokenLib.transferNative(to, amount) : TokenLib.transfer(token, to, amount);
}

/**
* @dev Sets the fee receiver address.
*
* Requirements:
* - The caller must be the router owner.
*/
function setFeeReceiver(address feeReceiver) external override {
if (msg.sender != Ownable(_router).owner()) revert RouterLogic__OnlyRouterOwner();
_setFeeReceiver(feeReceiver);
}

/**
* @dev Helper function to check if the amount is valid.
*
Expand All @@ -163,19 +189,60 @@ contract RouterLogic is RouterAdapter, IRouterLogic {
function _startAndVerify(bytes calldata route, address tokenIn, address tokenOut)
private
view
returns (uint256 ptr, uint256 nbTokens, uint256 nbSwaps)
returns (uint256 feePtr, uint256 ptr, uint256 nbTokens, uint256 nbSwaps)
{
if (msg.sender != _router) revert RouterLogic__OnlyRouter();

(ptr, nbTokens, nbSwaps) = PackedRoute.start(route);

if (nbTokens < 2) revert RouterLogic__InsufficientTokens();
if (nbSwaps == 0) revert RouterLogic__ZeroSwap();

(uint256 nextPtr, bytes32 value) = PackedRoute.next(route, ptr);
uint256 flags = PackedRoute.getFlags(value);

if (Flags.id(flags) == Flags.FEE_ID) {
if (nbSwaps < 2) revert RouterLogic__ZeroSwap();
unchecked {
--nbSwaps;
}
feePtr = ptr;
ptr = nextPtr;
} else {
if (nbSwaps == 0) revert RouterLogic__ZeroSwap();
}

if (PackedRoute.token(route, 0) != tokenIn) revert RouterLogic__InvalidTokenIn();
if (PackedRoute.token(route, nbTokens - 1) != tokenOut) revert RouterLogic__InvalidTokenOut();
}

/**
* @dev Returns the fee amount added on the swap.
* The fee is calculated as follows:
* - if `isSwapExactIn`, the fee is calculated as `(amountIn * feePercent) / BPS`
* else, the fee is calculated as `(amountIn * BPS) / (BPS - feePercent)`
*
* Requirements:
* - The data must use a the valid format: `(address(0), feePercent, Flags.FEE_ID, 0, 0)`
* - The feePercent must be greater than 0 and less than BPS.
*/
function _getFee(bytes calldata route, uint256 feePtr, uint256 amountIn, bool isSwapExactIn)
private
pure
returns (uint256 feeAmount)
{
if (feePtr > 0) {
(, bytes32 value) = PackedRoute.next(route, feePtr);

(address pair, uint256 feePercent, uint256 flags, uint256 tokenInId, uint256 tokenOutId) =
PackedRoute.decode(value);

if ((uint256(uint160(pair)) | flags | tokenInId | tokenOutId) != 0) revert RouterLogic__InvalidFeeData();
if (feePercent == 0 || feePercent >= BPS) revert RouterLogic__InvalidFeePercent();

feeAmount = isSwapExactIn ? (amountIn * feePercent) / BPS : (amountIn * feePercent) / (BPS - feePercent);
}
}

/**
* @dev Helper function to return the amountIn for each swap in the route and the amountIn of the first token.
* The function will most likely revert if the same pair is used twice, or if the output of a pair is changed
Expand Down Expand Up @@ -314,4 +381,29 @@ contract RouterLogic is RouterAdapter, IRouterLogic {

return (token, amount);
}

/**
* @dev Helper function to send the fee to the fee receiver.
*/
function _sendFee(bytes calldata route, uint256 tokenInId, address from, uint256 feeAmount) private {
if (feeAmount > 0) {
address feeReceiver = _feeReceiver;
(address token,) = _transfer(route, tokenInId, from, feeReceiver, feeAmount);
emit FeeSent(token, from, feeReceiver, feeAmount);
}
}

/**
* @dev Helper function to set the fee receiver address.
*
* Requirements:
* - The fee receiver address must not be the zero address.
*/
function _setFeeReceiver(address feeReceiver) private {
if (feeReceiver == address(0)) revert RouterLogic__InvalidFeeReceiver();

_feeReceiver = feeReceiver;

emit FeeReceiverSet(feeReceiver);
}
}
7 changes: 7 additions & 0 deletions src/interfaces/IForwarderLogic.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,21 @@ interface IForwarderLogic {
error ForwarderLogic__RouterUpdateFailed();
error ForwarderLogic__UntrustedRouter();
error ForwarderLogic__Blacklisted();
error ForwarderLogic__InvalidFeeReceiver();

event TrustedRouterUpdated(address indexed router, bool trusted);
event BlacklistUpdated(address indexed account, bool blacklisted);
event FeeReceiverSet(address indexed feeReceiver);
event FeeSent(address indexed token, address indexed from, address indexed to, uint256 amount);

function getTrustedRouterLength() external view returns (uint256);

function getTrustedRouterAt(uint256 index) external view returns (address);

function isBlacklisted(address account) external view returns (bool);

function getFeeReceiver() external view returns (address);

function swapExactIn(
address tokenIn,
address tokenOut,
Expand All @@ -35,4 +40,6 @@ interface IForwarderLogic {
function updateTrustedRouter(address router, bool add) external;

function updateBlacklist(address account, bool blacklisted) external;

function setFeeReceiver(address feeReceiver) external;
}
Loading