PyMC's MAP estimate (`pm.find_MAP`) is a thin wrapper around `scipy.optimize.minimize` that optimizes your model's log probability (using L-BFGS-B by default, or Powell if the model has discrete variables).
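To see what that amounts to, here is a minimal sketch of the same idea written directly against SciPy. It is not PyMC's actual code. The one-parameter model, its data `y`, and the helper names `neg_logp` and `neg_logp_grad` are hypothetical. The negative log posterior is minimized once with L-BFGS-B using an analytic gradient, and once with gradient-free Powell, which only needs function values:

```python
import numpy as np
from scipy.optimize import minimize

# Hypothetical toy data, modelled as y_i ~ Normal(mu, 1) with prior mu ~ Normal(0, 10).
y = np.array([1.2, 0.8, 1.5, 0.9, 1.1])

def neg_logp(theta):
    """Negative log posterior of mu, dropping additive constants."""
    mu = theta[0]
    log_prior = -0.5 * (mu / 10.0) ** 2
    log_lik = -0.5 * np.sum((y - mu) ** 2)
    return -(log_prior + log_lik)

def neg_logp_grad(theta):
    """Analytic gradient of neg_logp with respect to mu."""
    mu = theta[0]
    return np.array([mu / 100.0 - np.sum(y - mu)])

# Gradient-based route: L-BFGS-B, usable when gradients are available.
res = minimize(neg_logp, x0=np.zeros(1), jac=neg_logp_grad, method="L-BFGS-B")
mu_map = float(np.ravel(res.x)[0])

# Gradient-free route: Powell only evaluates neg_logp itself.
res_powell = minimize(neg_logp, x0=np.zeros(1), method="Powell")
mu_powell = float(np.ravel(res_powell.x)[0])

# The prior is conjugate, so the MAP has a closed form to check against.
mu_exact = y.sum() / (len(y) + 1 / 100.0)
print(mu_map, mu_powell, mu_exact)
```

Both optimizers should land close to the closed-form value `sum(y) / (n + 1/100)`. Powell typically needs more function evaluations to get there, because it has no gradient to follow.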
Do you have a reason to believe 28 non-zero columns is too many?
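If you are using version 4 of PyMC, you can compile your model's log probability into a JAX function:
```python
import pymc as pm
from pymc import sampling_jax

with pm.Model() as model:
    ...  # define your model here

# Takes a list of values (one per entry in model.value_vars) and returns a scalar.
jax_logp = sampling_jax.get_jaxified_logp(model)
```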
Then look at a library like JAXopt to get more control over how the log probability is optimized.