pywavelets with jax


Jakub Mitura

Feb 25, 2023, 10:31:12 AM
to PyWavelets
Hello, I need to use a 3D stationary wavelet scattering transform and its inverse in my JAX pipeline. Is it possible to use JAX as a backend in PyWavelets?

Deepu

May 22, 2023, 2:00:58 AM
to PyWavelets
As of my knowledge cutoff in September 2021, PyWavelets does not directly support JAX as a backend. PyWavelets is primarily designed to work with NumPy, and it uses NumPy arrays for computations.

However, JAX's `jax.numpy` module mirrors much of the NumPy API, and JAX arrays can be converted to and from NumPy arrays, so you can bridge the two libraries with some wrapper code. Keep in mind that PyWavelets' transforms are implemented in compiled C/Cython code that works on NumPy arrays, so JAX cannot trace through them with `jax.jit` or `jax.grad`.

Here's a possible approach you can consider:

1. Use JAX for array manipulation: Keep your data in JAX arrays (`jnp.ndarray`) rather than NumPy arrays for the rest of the pipeline. `jax.numpy` supports most NumPy operations, so the usual mathematical operations and transformations carry over.

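2. Transform data with PyWavelets: Convert the JAX arrays to NumPy arrays (`np.ndarray`) before passing them to PyWavelets. For the 3D stationary wavelet transform and its inverse, the relevant functions are `pywt.swtn` and `pywt.iswtn`. Note that PyWavelets provides the stationary wavelet transform itself, not a wavelet scattering transform; a scattering cascade would have to be built on top of it.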

3. Convert back to JAX: Once you have the transformed data from PyWavelets, convert the resulting NumPy arrays back to JAX arrays (`jnp.ndarray`, e.g. with `jnp.asarray`) to continue your pipeline in JAX.
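
The three steps above can be sketched as a pair of wrapper functions. This is a minimal sketch, assuming PyWavelets' `pywt.swtn` and `pywt.iswtn` for the 3D stationary wavelet transform and its inverse; `swt3d` and `iswt3d` are hypothetical helper names, and the Haar wavelet with two levels is only an example choice. It computes the stationary transform only (not a scattering transform), and because PyWavelets runs outside JAX, these wrappers cannot be used inside `jax.jit` or `jax.grad`.

```python
import jax
import jax.numpy as jnp
import numpy as np
import pywt


def swt3d(x, wavelet="haar", level=1):
    """3D stationary wavelet transform of a JAX array, computed by PyWavelets."""
    # Step 2: copy the JAX data into a writable NumPy array for PyWavelets.
    x_np = np.array(x)
    # pywt.swtn returns one dict per level; each dict maps 3-letter keys
    # ('aaa', 'aad', ..., 'ddd') to arrays with the same shape as the input.
    coeffs = pywt.swtn(x_np, wavelet, level=level)
    # Step 3: convert every coefficient array back to a JAX array.
    return [{key: jnp.asarray(arr) for key, arr in lvl.items()} for lvl in coeffs]


def iswt3d(coeffs, wavelet="haar"):
    """Inverse of swt3d: convert to NumPy, call pywt.iswtn, convert back."""
    coeffs_np = [{key: np.array(arr) for key, arr in lvl.items()} for lvl in coeffs]
    return jnp.asarray(pywt.iswtn(coeffs_np, wavelet))


# pywt.swtn needs every axis length to be a multiple of 2**level.
x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
coeffs = swt3d(x, "haar", level=2)
x_rec = iswt3d(coeffs, "haar")
print(len(coeffs), sorted(coeffs[0]))
print(bool(jnp.allclose(x, x_rec, atol=1e-4)))
```

`np.array` (which always copies) is used instead of `np.asarray` so that PyWavelets is guaranteed to receive an ordinary writable NumPy array rather than a possibly read-only view of the JAX buffer.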

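This approach requires care at the conversion boundaries. Each conversion can involve copying data between accelerator and host memory, and because the PyWavelets call happens outside JAX's tracing, it cannot be placed inside a `jax.jit`-compiled function or differentiated with `jax.grad` as-is. If you need to call it from jitted code, `jax.pure_callback` can wrap a host-side function, but gradients will still not flow through it unless you define them yourself (for example with `jax.custom_vjp`).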
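
If the PyWavelets call has to sit inside a jitted function, one option is `jax.pure_callback`, which runs a host-side Python function and tells the compiler the shapes and dtypes it will return. The sketch below assumes a single-level 3D Haar stationary transform; `KEYS`, `_swt3d_host` and `swt3d_level1` are hypothetical names. As far as I know, `jax.pure_callback` has no built-in derivative rule, so taking `jax.grad` through it would require a user-supplied `jax.custom_vjp` or `jax.custom_jvp`.

```python
import jax
import numpy as np
import pywt

# Subband keys produced by pywt.swtn for 3D input: one letter per axis.
KEYS = ("aaa", "aad", "ada", "add", "daa", "dad", "dda", "ddd")


def _swt3d_host(x):
    """Runs on the host with NumPy input; JAX never traces this function."""
    (level1,) = pywt.swtn(np.array(x), "haar", level=1)  # writable copy
    # Cast so the returned dtypes match the spec declared in swt3d_level1.
    return {key: np.asarray(level1[key], dtype=x.dtype) for key in KEYS}


@jax.jit
def swt3d_level1(x):
    # Declare the shape and dtype of every output so XLA can compile around the call.
    out_spec = {key: jax.ShapeDtypeStruct(x.shape, x.dtype) for key in KEYS}
    return jax.pure_callback(_swt3d_host, out_spec, x)


x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
coeffs = swt3d_level1(x)
print(coeffs["ddd"].shape)
```

The jitted function still pays for a device-to-host round trip on every call; `pure_callback` only makes the PyWavelets step usable from compiled code, it does not make it run on the accelerator.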

It's worth noting that, beyond JAX itself, there is an ecosystem of third-party JAX libraries for numerical computing and signal processing. Depending on your specific requirements, you might look for JAX-native implementations of wavelet and scattering transforms, which integrate directly with `jax.jit` and offer additional benefits such as automatic differentiation.
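
As an illustration of what a JAX-native approach looks like, here is a hand-written, one-level 3D undecimated (stationary) Haar transform and its inverse in pure `jax.numpy`. It is a sketch, not any library's API: all names are hypothetical, it assumes circular boundary handling via `jnp.roll`, and its shift convention is not guaranteed to match PyWavelets' coefficients exactly. Because it is ordinary JAX code, it works under `jax.jit` and `jax.grad`; a scattering transform would still have to be built on top of it.

```python
import math

import jax
import jax.numpy as jnp

INV_SQRT2 = 1.0 / math.sqrt(2.0)


def _haar_split(x, axis):
    # Undecimated Haar step with circular boundary: no downsampling,
    # so both outputs keep the input shape.
    shifted = jnp.roll(x, -1, axis=axis)
    return (x + shifted) * INV_SQRT2, (x - shifted) * INV_SQRT2


def _haar_merge(a, d, axis):
    # Each (a, d) pair determines x twice (directly and via the shifted copy);
    # averaging the two estimates is the usual stationary-transform inverse.
    direct = (a + d) * INV_SQRT2
    from_shifted = jnp.roll((a - d) * INV_SQRT2, 1, axis=axis)
    return 0.5 * (direct + from_shifted)


def haar_swt3d(x):
    """One level of a 3D undecimated Haar transform -> dict of 8 subbands."""
    bands = {"": x}
    for axis in range(3):
        # Key letter i records whether axis i got the approximation or detail filter.
        bands = {
            key + tag: band
            for key, parent in bands.items()
            for tag, band in zip("ad", _haar_split(parent, axis))
        }
    return bands


def haar_iswt3d(bands):
    """Inverse of haar_swt3d: undo the axes in reverse order."""
    for axis in (2, 1, 0):
        prefixes = sorted({key[:-1] for key in bands})
        bands = {p: _haar_merge(bands[p + "a"], bands[p + "d"], axis) for p in prefixes}
    return bands[""]


def detail_energy(x):
    # Example differentiable objective: total energy in the detail subbands.
    return sum(jnp.sum(b ** 2) for key, b in haar_swt3d(x).items() if key != "aaa")


x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
bands = jax.jit(haar_swt3d)(x)
x_rec = jax.jit(haar_iswt3d)(bands)
grad = jax.grad(detail_energy)(x)
print(len(bands), bool(jnp.allclose(x, x_rec, atol=1e-5)), grad.shape)
```

With the orthonormal Haar filters used here, each axis doubles the total energy (no downsampling), so the eight subbands together carry 8 times the energy of the input; in this Haar case either estimate in `_haar_merge` alone would also reconstruct the input exactly.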

Please consult the JAX documentation and available libraries to explore the possibilities for implementing 3D stationary wavelet scattering and inverse transforms within a JAX pipeline.