
JAX and NumPy functions: How to structure them? #58

@jakharkaran

Description


This repository contains a 2D Navier-Stokes equation solver and data processing methods. The solver, written using the JAX library, is computationally expensive and leverages GPU acceleration. The less intensive post-processing methods use NumPy.

Some functions are needed by both the solver and the post-processing code. Currently each of them exists as two duplicate copies, one written with JAX and one with NumPy. What is the best way to restructure the code to remove this duplication?

  • Pass a backend variable, backend = 'numpy' or 'jax': suitable for functions whose underlying structure is identical in NumPy and JAX, with only the library calls differing.
  • Write all the functions in JAX (needs more thought):
    Pros:
      • User friendliness
    Cons:
      • Requires careful consideration of JAX-specific function implementations.
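A minimal sketch of the first option, assuming the shared helper is selected by a string flag. The names get_backend and vorticity, and the choice of a spectral vorticity calculation on a 2π-periodic grid, are illustrative assumptions, not code from this repository. The flag is mapped once to an array module (numpy or jax.numpy), and the function body then calls only functions that both modules provide.

```python
import numpy as np

try:
    import jax.numpy as jnp
except ImportError:  # NumPy-only environments can still use backend="numpy"
    jnp = None


def get_backend(backend="numpy"):
    """Map the (hypothetical) backend flag to an array module."""
    if backend == "numpy":
        return np
    if backend == "jax":
        if jnp is None:
            raise ImportError("backend='jax' was requested but JAX is not installed")
        return jnp
    raise ValueError(f"unknown backend {backend!r}; expected 'numpy' or 'jax'")


def vorticity(u, v, backend="numpy"):
    """Hypothetical shared helper: omega = dv/dx - du/dy on a 2*pi-periodic grid.

    u and v are indexed as [x, y]. The body uses only calls that exist in
    both numpy and jax.numpy, so one implementation serves both backends.
    """
    xp = get_backend(backend)
    nx, ny = u.shape
    # Integer wavenumbers; valid because the domain length is assumed to be 2*pi.
    kx = xp.fft.fftfreq(nx, d=1.0 / nx)
    ky = xp.fft.fftfreq(ny, d=1.0 / ny)
    KX, KY = xp.meshgrid(kx, ky, indexing="ij")
    omega_hat = 1j * KX * xp.fft.fft2(v) - 1j * KY * xp.fft.fft2(u)
    return xp.real(xp.fft.ifft2(omega_hat))


# Usage from NumPy post-processing code.
n = 64
x = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
X, Y = np.meshgrid(x, x, indexing="ij")
omega = vorticity(np.cos(Y), np.sin(X), backend="numpy")
# Analytically, omega = cos(x) + sin(y) for this velocity field.
print(np.allclose(omega, np.cos(X) + np.sin(Y)))  # True
```

If the JAX path is compiled, the flag has to be marked static, e.g. jax.jit(vorticity, static_argnames="backend"), because a string cannot be traced. A common variant passes the module itself (xp=np or xp=jnp) instead of a string, which removes the lookup function entirely.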
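A minimal sketch of the second option, assuming JAX is installed wherever post-processing runs. The shared function (again a hypothetical spectral vorticity, plus a hypothetical zero_first_row helper) is written once with jax.numpy; NumPy post-processing code passes NumPy arrays straight in and converts the result back with np.asarray. Two examples of JAX-specific details this option has to handle appear even at this size: JAX uses 32-bit floats by default, so float64 must be enabled explicitly to match NumPy results, and JAX arrays are immutable, so in-place edits become .at[...].set(...).

```python
import numpy as np
import jax
import jax.numpy as jnp

# JAX defaults to float32; enable float64 at startup so results match NumPy.
jax.config.update("jax_enable_x64", True)


@jax.jit
def vorticity(u, v):
    """Hypothetical shared helper: omega = dv/dx - du/dy on a 2*pi-periodic grid indexed [x, y]."""
    nx, ny = u.shape  # shapes are static under jit, so they can size fftfreq
    kx = jnp.fft.fftfreq(nx, d=1.0 / nx)
    ky = jnp.fft.fftfreq(ny, d=1.0 / ny)
    KX, KY = jnp.meshgrid(kx, ky, indexing="ij")
    omega_hat = 1j * KX * jnp.fft.fft2(v) - 1j * KY * jnp.fft.fft2(u)
    return jnp.real(jnp.fft.ifft2(omega_hat))


def zero_first_row(field):
    """Hypothetical helper: JAX arrays are immutable, so use .at[].set() instead of field[0, :] = 0."""
    return field.at[0, :].set(0.0)


# NumPy post-processing: pass NumPy arrays in, convert the jax.Array results back.
n = 64
x = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
X, Y = np.meshgrid(x, x, indexing="ij")
omega = np.asarray(vorticity(np.cos(Y), np.sin(X)))
masked = np.asarray(zero_first_row(jnp.asarray(omega)))
print(omega.dtype, np.allclose(omega, np.cos(X) + np.sin(Y)))  # float64 True
```

With this layout there is only one implementation to maintain; the trade-off is that the post-processing environment also needs JAX installed, and results come back as immutable jax.Array objects until they are converted with np.asarray.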

Labels: enhancement (New feature or request), question (Further information is requested)
