This repository contains a 2D Navier-Stokes solver and post-processing routines. The solver is written in JAX and, because it is computationally expensive, uses GPU acceleration. The less intensive post-processing routines use NumPy.
Some functions are needed by both the solver and the post-processing. Currently each of them exists in two copies, one written for JAX and one for NumPy. How should the code be restructured to remove this duplication?
- Passing a backend variable
  backend = 'numpy' or 'jax': suitable for functions whose structure is identical in NumPy and JAX, where only the library calls differ.
- Writing all shared functions in JAX only (still under consideration)
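  Pros:
  - User-friendliness: a single implementation to call and maintain.
  Cons:
  - Requires care with JAX-specific behavior, for example immutable arrays (in-place updates must use x.at[i].set(...)), no data-dependent Python control flow inside jax.jit, and 32-bit floats by default.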
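A minimal sketch of the first option (passing a backend variable), assuming the shared routines only call functions that exist in both `numpy` and `jax.numpy`. The names `get_xp` and `vorticity`, the periodic grid and the `[y, x]` array layout are hypothetical illustrations, not code from this repository. JAX is imported only when requested, so NumPy-only post-processing does not need it installed.

```python
import numpy as np


def get_xp(backend):
    """Return the array module for the requested backend ('numpy' or 'jax')."""
    if backend == "numpy":
        return np
    if backend == "jax":
        # Imported lazily so NumPy-only post-processing does not require JAX.
        import jax.numpy as jnp
        return jnp
    raise ValueError(f"unknown backend: {backend!r}")


def vorticity(u, v, dx, dy, backend="numpy"):
    """Hypothetical shared helper: w = dv/dx - du/dy by central differences.

    Assumes a periodic grid and arrays indexed as [y, x],
    so axis=1 is the x direction and axis=0 is the y direction.
    """
    xp = get_xp(backend)
    # roll(f, -1)[j] = f[j + 1] and roll(f, 1)[j] = f[j - 1] (periodic wrap).
    dvdx = (xp.roll(v, -1, axis=1) - xp.roll(v, 1, axis=1)) / (2 * dx)
    dudy = (xp.roll(u, -1, axis=0) - xp.roll(u, 1, axis=0)) / (2 * dy)
    return dvdx - dudy


# Usage on the post-processing side: u = 0, v = sin(x) gives w close to cos(x).
n = 64
x = np.linspace(0.0, 2 * np.pi, n, endpoint=False)
X, Y = np.meshgrid(x, x)  # X[i, j] = x[j], so X varies along axis=1
dx = x[1] - x[0]
w = vorticity(np.zeros_like(X), np.sin(X), dx, dx, backend="numpy")
```

If the solver jit-compiles such a function directly, the string argument has to be marked static, e.g. `jax.jit(vorticity, static_argnames="backend")`, because JAX cannot trace a Python string.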
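A minimal sketch of the second option, assuming JAX is installed wherever post-processing runs. `kinetic_energy` and `zero_boundary` are hypothetical helpers, not functions from this repository. The sketch shows NumPy arrays passed straight into jitted JAX code, a functional update with `.at[...].set(...)` where NumPy code would assign in place, and conversion of the results back to NumPy types.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def kinetic_energy(u, v):
    """Hypothetical shared helper: domain-averaged kinetic energy 0.5 * <u^2 + v^2>."""
    return 0.5 * jnp.mean(u**2 + v**2)


def zero_boundary(field):
    """Hypothetical helper: set the first and last rows of a field to zero."""
    field = jnp.asarray(field)
    # JAX arrays are immutable, so `field[0, :] = 0.0` would raise an error;
    # .at[...].set(...) returns an updated copy instead.
    field = field.at[0, :].set(0.0)
    field = field.at[-1, :].set(0.0)
    return field


# Post-processing side: NumPy arrays can be passed in directly. Results are JAX
# arrays, so convert them with float() / np.asarray() for NumPy-only code.
# Unless jax_enable_x64 is enabled, JAX computes in float32 by default.
u_np = np.ones((8, 8))
v_np = np.zeros((8, 8))
ke = float(kinetic_energy(u_np, v_np))
masked = np.asarray(zero_boundary(u_np))
```

The trade-off is that every post-processing script then depends on JAX, and any existing NumPy code that mutates arrays in place has to be rewritten in this functional style.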