I am using both Python's NumPy and its jax.numpy replacement, and I often need the same user-defined function for both. The issue is that jax.numpy is such a good drop-in replacement that I'm very often writing essentially the same function for both libraries. To give a very simple example, I might have:
module_numpy.py:
import numpy as np
def get_loss(x):
    return np.square(x) + 1
module_jax.py:
import jax.numpy as jnp
def get_loss(x):
    return jnp.square(x) + 1
Here I am basically writing the same code twice. If I change one of these functions, I have to make the same change to the other, and that becomes a lot to maintain. When I come back to a function later, I may not know which version is the most "correct" or the most recently updated.
Question: How can I essentially redefine a library by replacing the global variables under which module_numpy.get_loss is defined? For instance, I have tried:
module_jax_auto.py:
import jax.numpy as jnp
import module_numpy
import inspect
f = eval(inspect.getsource(module_numpy.get_loss).replace('np', 'jnp'))
but Python complains with a syntax error. Apparently I can't use def inside eval, since eval only accepts expressions, but even if I could, I'm not sure this is the "best" way to solve the problem of maintaining two parallel libraries. I do always want to use the NumPy version to define the JAX version, and not the other way around.
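One way to do what the question asks without going through source text at all is to reuse the function's compiled code object and pair it with a different globals dictionary. The sketch below is a minimal illustration, not a library API: `with_globals` is a hypothetical helper name, and the local `get_loss` stands in for `module_numpy.get_loss`. It relies only on `types.FunctionType` and the standard function attributes `__code__`, `__globals__`, `__defaults__`, `__kwdefaults__` and `__closure__`.

```python
import types

import numpy as np


def get_loss(x):
    # Stand-in for module_numpy.get_loss; `np` is looked up in the module globals at call time.
    return np.square(x) + 1


def with_globals(func, **replacements):
    """Hypothetical helper: return a copy of `func` whose global lookups see `replacements`.

    The code object is reused unchanged; only the globals dict differs, so the
    original `func` keeps using its own module namespace.
    """
    new_globals = dict(func.__globals__)  # shallow snapshot of the defining module's namespace
    new_globals.update(replacements)      # e.g. {"np": jax.numpy}
    clone = types.FunctionType(
        func.__code__,
        new_globals,
        func.__name__,
        func.__defaults__,
        func.__closure__,
    )
    # Carry over the remaining metadata that FunctionType does not take as arguments.
    clone.__kwdefaults__ = func.__kwdefaults__
    clone.__doc__ = func.__doc__
    clone.__qualname__ = func.__qualname__
    return clone
```

With JAX installed, `module_jax_auto.py` could then use `get_loss = with_globals(module_numpy.get_loss, np=jnp)` after `import jax.numpy as jnp`. Two caveats of this design: only the cloned function sees the new `np`, so any other `module_numpy` function it calls still runs on NumPy (each function carries its own globals), and the globals are a snapshot, so later changes to `module_numpy`'s namespace are not picked up by the clone.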
It's worth noting that in some cases I do need to define the JAX version of a function differently, so I will exclude some functions from being "auto-generated" from the NumPy version.
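To cover the exclusions, the same code-object idea can be applied to a whole module: clone every function defined in `module_numpy` into the JAX module's own namespace, skip the excluded names, and let hand-written definitions fill those gaps. Because every clone uses the JAX module's globals, a derived function that calls another function (derived or hand-written) picks up the JAX version. This is a sketch under assumptions: `derive_into` is a hypothetical helper, and it only handles plain module-level functions (methods of classes defined in `module_numpy` would keep using NumPy).

```python
import types


def derive_into(target_globals, src_module, replacements, exclude=()):
    """Hypothetical helper: copy the functions defined in `src_module` into `target_globals`.

    Each copy is re-bound so its global lookups go through `target_globals`,
    which lets derived functions call each other (and any hand-written
    overrides) within the target namespace. Names in `exclude` are skipped.
    """
    for name, value in vars(src_module).items():
        if name.startswith("__") or name in exclude:
            continue  # skip module dunders and functions that will be written by hand
        if isinstance(value, types.FunctionType) and value.__module__ == src_module.__name__:
            clone = types.FunctionType(
                value.__code__,
                target_globals,  # shared namespace, not a per-function copy
                value.__name__,
                value.__defaults__,
                value.__closure__,
            )
            clone.__kwdefaults__ = value.__kwdefaults__
            clone.__doc__ = value.__doc__
            target_globals[name] = clone
        else:
            # Constants, imported modules, helpers from other modules: keep existing target names.
            target_globals.setdefault(name, value)
    # Apply the backend swap last so it wins over the copied `np` binding.
    target_globals.update(replacements)
```

In `module_jax_auto.py` this might look like `derive_into(globals(), module_numpy, {"np": jnp}, exclude={"special_loss"})`, followed by a hand-written `def special_loss(x): ...`, where `special_loss` is a placeholder name. An excluded name that is never defined by hand fails with a `NameError` when a derived function calls it, rather than silently falling back to NumPy.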
What's the best approach here?