Make convolve mode symbolic to avoid unnecessary large convolution in gradient graph #1522

Draft: wants to merge 4 commits into base `main` from `make_convolve_mode_symbolic`
Conversation

@ricardoV94 (Member) commented Jul 4, 2025

Instead of statically parametrizing the convolution mode at the Op level, the Op now uses a scalar boolean that can be set symbolically. This works fine everywhere except JAX, where we can't branch on a symbolic value like that.

As long as the mode is constant (which is the case when the original convolve was full, or when the core shapes are statically known), this should be fine, except for one hiccup: in the dispatch of Blockwise we couldn't see the outer inputs.

I added some functionality so that, when we create the dummy core node, inputs without batch dimensions are propagated. Such inputs won't change across iterations, so making compile or rewrite decisions based on them should be safe. This could also be used for infer_shape, for instance, which could help with lowering some Ops to Numba.

These changes would also allow us to compile a constant convolve mode in C/Numba, but benchmarks didn't show any gains, so I didn't bother. In any case, future Blockwise implementations for certain Ops can make use of that information.
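The scalar-boolean parametrization can be sketched in plain Python (a hypothetical illustration of the idea, not the actual Op implementation; `conv_full` and `conv` are made-up helpers). It relies on the fact that the valid output is the central slice of the full output:

```python
def conv_full(x, k):
    """Full 1D convolution of two Python lists."""
    out = [0.0] * (len(x) + len(k) - 1)
    for i, xi in enumerate(x):
        for j, kj in enumerate(k):
            out[i + j] += xi * kj
    return out


def conv(x, k, full_mode):
    """Convolution whose mode is a runtime boolean, mirroring the
    scalar-boolean parametrization (illustrative only): valid mode
    is the central slice of the full output."""
    full = conv_full(x, k)
    if full_mode:
        return full
    m = min(len(x), len(k))
    return full[m - 1 : len(full) - (m - 1)]
```

When the flag ends up being a compile-time constant, a backend is free to specialize the branch away; JAX needs exactly that, since it cannot trace a branch on a data-dependent boolean.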

Relevant benchmark tests:

Before:
--------------------------------------------------------------------------- benchmark: 8 tests ---------------------------------------------------------------------------
Name (time in us)                                                   Min                    Max                   Mean              StdDev                 Median
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_convolve1d_benchmark_numba[batch=False-mode=full]           6.3820 (1.06)        117.5200 (1.72)          6.7371 (1.06)       1.0520 (1.27)          6.5720 (1.05)
test_convolve1d_benchmark_numba[batch=False-mode=valid]          6.0410 (1.0)          81.1920 (1.19)          6.3736 (1.0)        0.8259 (1.0)           6.2410 (1.0)
test_convolve1d_benchmark_numba[batch=True-mode=full]           14.0660 (2.33)      1,015.9230 (14.87)        14.6380 (2.30)       5.4193 (6.56)         14.3370 (2.30)
test_convolve1d_benchmark_numba[batch=True-mode=valid]          12.8040 (2.12)         68.3380 (1.0)          13.4789 (2.11)       1.4135 (1.71)         13.0740 (2.09)
test_convolve1d_grad_benchmark_numba[full]                     176.2610 (29.18)       190.4460 (2.79)        181.4098 (28.46)      6.4607 (7.82)        177.5830 (28.45)
test_convolve1d_grad_benchmark_numba[valid]                 10,654.1110 (>1000.0)  10,663.5190 (156.04)   10,658.8200 (>1000.0)    3.5995 (4.36)     10,658.2390 (>1000.0)
test_convolve1d_grad_benchmark_c[full]                          99.6770 (16.50)       188.2430 (2.75)        107.0415 (16.79)      8.2845 (10.03)       103.7550 (16.62)
test_convolve1d_grad_benchmark_c[valid]                      1,945.9860 (322.13)    3,263.2340 (47.75)     2,070.7502 (324.90)   165.5805 (200.49)    1,991.7020 (319.13)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After:
--------------------------------------------------------------------- benchmark: 8 tests --------------------------------------------------------------------
Name (time in us)                                                Min                 Max                Mean             StdDev              Median
-------------------------------------------------------------------------------------------------------------------------------------------------------------
test_convolve1d_benchmark_numba[batch=False-mode=full]        6.4320 (1.05)     121.2880 (1.92)       7.0089 (1.08)      1.4235 (1.68)       6.6420 (1.05)
test_convolve1d_benchmark_numba[batch=False-mode=valid]       6.1220 (1.0)       63.2990 (1.0)        6.4999 (1.0)       0.8496 (1.0)        6.3120 (1.0)
test_convolve1d_benchmark_numba[batch=True-mode=full]        14.5480 (2.38)     133.0200 (2.10)      16.1110 (2.48)      2.5351 (2.98)      14.8580 (2.35)
test_convolve1d_benchmark_numba[batch=True-mode=valid]       13.4050 (2.19)      82.8950 (1.31)      14.2643 (2.19)      1.8216 (2.14)      13.7060 (2.17)
test_convolve1d_grad_benchmark_numba[full]                  177.7730 (29.04)    197.0690 (3.11)     183.0290 (28.16)     8.1456 (9.59)     179.0460 (28.37)
test_convolve1d_grad_benchmark_numba[valid]                 175.9900 (28.75)    184.9760 (2.92)     180.1314 (27.71)     4.4441 (5.23)     177.9230 (28.19)
test_convolve1d_grad_benchmark_c[full]                      107.2810 (17.52)    781.7850 (12.35)    122.0496 (18.78)    24.0425 (28.30)    112.4210 (17.81)
test_convolve1d_grad_benchmark_c[valid]                     115.8170 (18.92)    401.0020 (6.34)     127.0086 (19.54)    19.5353 (22.99)    120.4860 (19.09)
-------------------------------------------------------------------------------------------------------------------------------------------------------------

Note the worst-case scenario: previously we were doing a full convolution for the smaller input in the gradient of a valid convolution (~10.7 ms vs ~180 µs in Numba).
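The asymmetry behind that worst case can be seen from the gradient identity for a 1D valid convolution: if `y = conv_valid(x, k)` with `len(x) >= len(k)`, the gradient of `L = sum(g * y)` with respect to `x` is a full convolution of the short output gradient `g` with the flipped kernel. A minimal pure-Python sketch of the math (hypothetical helper names, not PyTensor code):

```python
def conv_full(x, k):
    """Full 1D convolution of two Python lists."""
    out = [0.0] * (len(x) + len(k) - 1)
    for i, xi in enumerate(x):
        for j, kj in enumerate(k):
            out[i + j] += xi * kj
    return out


def conv_valid(x, k):
    """Valid mode is the central slice of the full output."""
    m = min(len(x), len(k))
    full = conv_full(x, k)
    return full[m - 1 : len(full) - (m - 1)]


def grad_wrt_x(g, k):
    """d/dx of sum(g * conv_valid(x, k)): a full convolution of the
    output gradient with the flipped kernel. Both g and k are small,
    so this full convolution is cheap."""
    return conv_full(g, k[::-1])
```

For example, `x = [1., 2., 3., 4.]` and `k = [1., -1.]` give `y = [1., 1., 1.]`, and with `g = [1., 2., 3.]` the gradient is `[-1., -1., -1., 3.]`, which can be checked by finite differences.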


📚 Documentation preview 📚: https://pytensor--1522.org.readthedocs.build/en/1522/

@ricardoV94 force-pushed the make_convolve_mode_symbolic branch 3 times, most recently from 48a5ce3 to 95b4cb3 (July 6, 2025 19:19)
@ricardoV94 force-pushed the make_convolve_mode_symbolic branch 3 times, most recently from 281bbf9 to fe2ea6b (July 7, 2025 11:21)
@ricardoV94 force-pushed the make_convolve_mode_symbolic branch from fe2ea6b to 12e3123 (July 7, 2025 11:38)
@ricardoV94 force-pushed the make_convolve_mode_symbolic branch from 12e3123 to 5be6968 (July 7, 2025 11:54)
codecov bot commented Jul 7, 2025

Codecov Report

Attention: Patch coverage is 68.42105% with 48 lines in your changes missing coverage. Please review.

Project coverage is 82.04%. Comparing base (7584614) to head (5be6968).

Files with missing lines                      Patch %   Lines
pytensor/link/numba/dispatch/signal/conv.py   16.66%    35 Missing ⚠️
pytensor/tensor/signal/conv.py                80.48%    4 Missing and 4 partials ⚠️
pytensor/tensor/blockwise.py                  93.75%    0 Missing and 3 partials ⚠️
pytensor/link/jax/dispatch/signal/conv.py     80.00%    2 Missing ⚠️

❌ Your patch check has failed because the patch coverage (68.42%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1522      +/-   ##
==========================================
- Coverage   82.04%   82.04%   -0.01%     
==========================================
  Files         231      230       -1     
  Lines       52364    52345      -19     
  Branches     9217     9212       -5     
==========================================
- Hits        42962    42946      -16     
- Misses       7094     7095       +1     
+ Partials     2308     2304       -4     
Files with missing lines                      Coverage Δ
pytensor/link/jax/dispatch/blockwise.py       100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/blockwise.py     90.00% <100.00%> (ø)
pytensor/tensor/basic.py                      91.69% <100.00%> (-0.02%) ⬇️
pytensor/tensor/rewriting/blockwise.py        96.55% <100.00%> (+0.47%) ⬆️
pytensor/tensor/rewriting/subtensor_lift.py   92.28% <100.00%> (+0.51%) ⬆️
pytensor/link/jax/dispatch/signal/conv.py     87.50% <80.00%> (-12.50%) ⬇️
pytensor/tensor/blockwise.py                  89.31% <93.75%> (+0.04%) ⬆️
pytensor/tensor/signal/conv.py                87.50% <80.48%> (-7.63%) ⬇️
pytensor/link/numba/dispatch/signal/conv.py   32.69% <16.66%> (+0.69%) ⬆️