
I am using both NumPy and its drop-in replacement jax.numpy, and I often need the same user-defined function for both. The issue is that jax.numpy is such a good drop-in replacement that very often I'm writing essentially the same function for both libraries. To give a very simple example, I might have:

module_numpy.py:

import numpy as np

def get_loss(x):
    return np.square(x) + 1

module_jax.py:

import jax.numpy as jnp

def get_loss(x):
    return jnp.square(x) + 1

where I write essentially the same code twice. The problem is that whenever I change one of these functions, I have to make the same change to the other, and that becomes a lot to maintain. When I come back to a function later, I may not remember which version is the "correct" or more recently updated one.

Question: How can I essentially redefine a library by replacing the global variables under which module_numpy.get_loss is defined? For instance, I have tried:

module_jax_auto.py:

import jax.numpy as jnp
import module_numpy
import inspect

f = eval(inspect.getsource(module_numpy.get_loss).replace('np', 'jnp'))

but Python complains with a syntax error. Apparently I can't use def inside eval, but even if I could, I'm not sure this is the "best" way to solve the problem of maintaining two parallel libraries. I do, however, always want to use the NumPy version to define the JAX version, and not the other way around.
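The idea of swapping the globals can be made to work without going through source text at all: a Python function object keeps a reference to the globals dict it looks names up in (`__globals__`), and `types.FunctionType` can build a new function from the same compiled code (`__code__`) with a different globals dict. A minimal sketch, assuming the NumPy function refers to NumPy only through the global name `np`; `rebind_np` is a hypothetical helper name, and `get_loss` is defined inline here as a stand-in for `module_numpy.get_loss`.

```python
import types

import numpy as np
import jax.numpy as jnp


# Stand-in for module_numpy.get_loss, written against `np` as usual.
def get_loss(x):
    return np.square(x) + 1


def rebind_np(func, backend):
    """Return a copy of `func` whose global name `np` refers to `backend`.

    Hypothetical helper: only `np` is swapped. Every other global name is
    taken from a snapshot of the function's original module namespace.
    """
    new_globals = dict(func.__globals__)  # copy, so the NumPy module is untouched
    new_globals["np"] = backend
    new_func = types.FunctionType(
        func.__code__,  # same compiled bytecode: no source text, eval or exec
        new_globals,
        func.__name__,
        func.__defaults__,
        func.__closure__,
    )
    new_func.__kwdefaults__ = func.__kwdefaults__
    new_func.__qualname__ = func.__qualname__
    new_func.__doc__ = func.__doc__
    return new_func


get_loss_jax = rebind_np(get_loss, jnp)

print(get_loss(4))      # computed with NumPy
print(get_loss_jax(4))  # same code, computed with jax.numpy
```

In the question's layout this would be `get_loss = rebind_np(module_numpy.get_loss, jnp)` inside module_jax_auto.py. Note that the globals dict is a snapshot: if `get_loss` calls another helper from module_numpy, that helper still runs on NumPy unless it is rebound too.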

It's worth noting that in some cases I do need to define the JAX version of a function differently, so I will exclude those functions from being "auto-generated" from the NumPy version.
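Because some functions need a hand-written JAX version, the same globals-swapping trick can be applied module-wide with an exclusion list. A minimal sketch, assuming the NumPy module refers to NumPy only through the global name `np`. `derive_module` is a hypothetical helper; `get_total_loss` and `get_special` are made-up example functions; and `module_numpy` is built in memory from a source string purely so the example runs on its own (in practice it would be the real `import module_numpy`). All copied functions share one new globals dict, so a function that calls another module function (`get_total_loss` calling `get_loss` here) also picks up the JAX version.

```python
import builtins
import types

import numpy as np
import jax.numpy as jnp

# In-memory stand-in for module_numpy.py, so the sketch runs on its own.
_SOURCE = '''
import numpy as np

def get_loss(x):
    return np.square(x) + 1

def get_total_loss(xs):
    # Calls another module-level function, which must see the same backend.
    return np.sum(get_loss(xs))

def get_special(x):
    return np.abs(x)
'''
module_numpy = types.ModuleType("module_numpy")
exec(_SOURCE, module_numpy.__dict__)


def derive_module(source, name, replacements, exclude=()):
    """Copy every function defined in `source` into a new module `name`.

    Hypothetical helper. All copies share the new module's namespace as their
    globals, with `replacements` (e.g. {"np": jnp}) applied. Functions listed
    in `exclude` are left out so they can be written by hand.
    """
    new_module = types.ModuleType(name)
    namespace = new_module.__dict__
    namespace["__builtins__"] = builtins
    # Start from the source module's non-dunder globals, then swap in replacements.
    namespace.update({k: v for k, v in vars(source).items() if not k.startswith("__")})
    namespace.update(replacements)
    for key, obj in vars(source).items():
        if not isinstance(obj, types.FunctionType) or obj.__module__ != source.__name__:
            continue  # skip imports, constants and functions from other modules
        if key in exclude:
            namespace.pop(key, None)  # don't silently fall back to the NumPy version
            continue
        clone = types.FunctionType(
            obj.__code__, namespace, obj.__name__, obj.__defaults__, obj.__closure__
        )
        clone.__kwdefaults__ = obj.__kwdefaults__
        clone.__doc__ = obj.__doc__
        namespace[key] = clone
    return new_module


module_jax = derive_module(module_numpy, "module_jax", {"np": jnp}, exclude={"get_special"})


# Hand-written JAX version of the excluded function.
def _get_special_jax(x):
    return jnp.abs(x) * 2


module_jax.get_special = _get_special_jax

xs = [1.0, 2.0]
print(module_numpy.get_total_loss(np.asarray(xs)))
print(module_jax.get_total_loss(jnp.asarray(xs)))
```

Excluded names are removed from the derived module rather than left pointing at the NumPy originals, so a missing JAX override fails loudly (AttributeError on access, NameError when another function calls it) instead of silently running NumPy code. Non-function globals, such as module-level constants computed with `np`, are copied unchanged, not recomputed with jnp.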

What's the best approach here?

  • Can you define all your functions in a class and then pass np as a parameter, or define a global within a module which is assigned to the version of numpy from outside the module? Commented Jan 6, 2022 at 16:51
  • Defining all these functions in a class feels like a hack to me, because the vast majority of this code is stateless. Defining an external global doesn't really work either, if I'm understanding your comment correctly, since I need to be able to use both versions in the same runtime, not just one or the other. Commented Jan 6, 2022 at 17:02
  • Is the code for both libraries the same, or is it just the function names? Commented Jan 6, 2022 at 19:16
  • Yes, in the vast majority of cases it is exactly the same code, as in the example above, just with jnp instead of np. Commented Jan 6, 2022 at 21:37

1 Answer


I'm not sure what you mean by "it feels like a hack". If it gets you where you want to be, does it matter whether it is a hack or not?

import numpy as np
import jax.numpy as jnp

# Replicated functions: written once against whichever module is stored in self.np
class ncommon:
    def __init__(self, np_module):
        self.np = np_module

    def get_loss(self, x):
        return self.np.square(x) + 1

# Specific functions for np
class SNP(ncommon):
    def __init__(self):
        super().__init__(np)

    def get_profit(self, x):
        return self.np.square(x) + 2

# Specific functions for jnp
class JNP(ncommon):
    def __init__(self):
        super().__init__(jnp)

    def get_profit(self, x):
        return self.np.square(x) + 3

if __name__ == '__main__':
    snpy = SNP()
    jnpy = JNP()
    print(f'common   {snpy.get_loss(4)}    {jnpy.get_loss(4)}')
    print(f'specific {snpy.get_profit(4)}  {jnpy.get_profit(4)}')

Alternatively, if you are dead against using classes, just pass in the version of numpy as a parameter.

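import numpy as np
import jax.numpy as jnp

# The numpy module to use is passed in explicitly (it shadows the global np inside)
def get_loss(np, x):
    return np.square(x) + 1

print(f'std {get_loss(np, 5)}  jax {get_loss(jnp, 5)}')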
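If repeating the backend argument at every call site gets tedious, `functools.partial` can bind it once, giving two ordinary one-argument functions that coexist in the same runtime (the concern raised in the comments). A minimal sketch; the names `get_loss_np` and `get_loss_jax` are illustrative, and `xp` is just a conventional name for "the array module in use".

```python
from functools import partial

import numpy as np
import jax.numpy as jnp


def get_loss(xp, x):
    # `xp` is whichever array module is passed in (numpy or jax.numpy).
    return xp.square(x) + 1


# Bind the backend once so callers get plain one-argument functions.
get_loss_np = partial(get_loss, np)
get_loss_jax = partial(get_loss, jnp)

print(get_loss_np(5), get_loss_jax(5))
```

Functions that genuinely need different JAX behaviour can still be written separately and simply not wrapped this way.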