
I have a number of operations I want to "fuse" together. Let's say there are 3 possible operations:

sq = lambda x: x**2
add = lambda x: x+3
mul = lambda x: x*5

I also have an array of operations:

ops = [add, sq, mul, sq]

I can then create a function from these operations:

def generateF(ops):
    def inner(x):
        for op in ops:
            x = op(x)
        return x
    return inner
f = generateF(ops)
f(3) # returns 32400

fastF = lambda x: (5*(x+3)**2)**2

f and fastF do the same thing, but fastF is around 1.7-2 times faster than f on my benchmark, which makes sense. My question is: how can I write a generateF function that returns a function as fast as fastF? The operations are restricted to basic operations like __add__, __mul__, __matmul__, __rrshift__, etc. (essentially most numeric operations). generateF can take as long as you'd like, because it will be done before reaching hot code.
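A minimal, self-contained timeit sketch of this comparison (my reconstruction, not the exact benchmark from the post):

```python
import timeit

sq = lambda x: x**2
add = lambda x: x + 3
mul = lambda x: x * 5
ops = [add, sq, mul, sq]

def generateF(ops):
    def inner(x):
        for op in ops:
            x = op(x)
        return x
    return inner

f = generateF(ops)
fastF = lambda x: (5*(x+3)**2)**2

# both compute the same value
assert f(3) == fastF(3) == 32400

# time the looped version against the fused expression
print(timeit.timeit(lambda: f(3), number=100_000))
print(timeit.timeit(lambda: fastF(3), number=100_000))
```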

The context is that this is part of my library, so I can define every legal operation and thus know exactly what they are. The operation definitions are not supplied by the end user (the user can only pick the order of the operations), so we can use any outside knowledge about them.

This might seem like premature optimization, but it is not, as f is hot code. Dropping to C is not an option, as the operations can be complex (think PyTorch tensor multiply) and x can be of any type. Currently, I'm thinking about modifying Python's bytecode, but that is very unpleasant, as the bytecode specification changes with every Python version, so I wanted to ask here first before diving into that solution.

  • I don't think you can do this. There's no way to extract the body of a function and merge it into a new function, so you're always going to get the overhead of calling 3 functions. Commented Nov 12, 2021 at 21:14
  • You could get exactly the same result as fastF by generating Python source code and then calling exec()/eval()/compile() on it. Your individual operations would become string manipulations: sq = lambda x: f"({x})**2" for example. Start with the innermost x actually being "x", then put "lambda x:" in front of the result. Commented Nov 12, 2021 at 21:16
  • If some of the operations are super long (PyTorch vector ops), is the slow step really the extra function-call overhead? Either the function-call overhead is the slowest thing happening (in which case, dip into C and accelerate it), or you're running slow Python functions (in which case, the extra calls aren't killing you). Commented Nov 12, 2021 at 21:17
  • @SilvioMayolo Yes, I understand that. I included PyTorch tensor ops just to be consistent with everything. f will mostly operate on random generic objects, which include PyTorch tensors, but mostly on lightweight objects, so function-call overhead here is significant. Commented Nov 12, 2021 at 21:20
  • Beside the point, but named lambdas are bad practice. Use a def instead. I'm pretty sure they're equally performant. Commented Nov 12, 2021 at 21:25
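The source-generation approach suggested in the comments can be sketched as follows. Each operation becomes a string transformer, the composed expression is built once, and eval compiles it into a single lambda with no per-op calls (the names sqS/addS/mulS are illustrative, not from the original post):

```python
# string-building counterparts of the original sq/add/mul lambdas
sqS  = lambda x: f"({x})**2"
addS = lambda x: f"({x}+3)"
mulS = lambda x: f"({x}*5)"

def generateF(str_ops):
    # thread the expression "x" through each string transformer
    expr = "x"
    for op in str_ops:
        expr = op(expr)
    # compile the fused expression once; hot code pays no loop or call cost
    return eval(f"lambda x: {expr}")

f = generateF([addS, sqS, mulS, sqS])
assert f(3) == 32400
```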

3 Answers


Here is a very hacky version of synthesizing a new function from the bytecode of the given functions. The basic technique is to keep the LOAD_FAST opcode only at the beginning of the first function and to strip the RETURN_VALUE opcode everywhere except at the end of the last function. This leaves the value being manipulated on the stack in between (what were originally) your functions. When you're done, there are no function calls at all.

import dis, inspect

sq = lambda x: x**2
add = lambda x: x+3
mul = lambda x: x*5

ops = [add, sq, mul, sq]

def synthF(ops):
    bytecode = bytearray()
    constants = []
    stacksize = 0
    for i, op in enumerate(ops):
        code = op.__code__
        # works only with functions having one argument and no other vars
        assert code.co_argcount == code.co_nlocals == 1
        assert not code.co_freevars
        stacksize = max(stacksize, code.co_stacksize)
        opcodes = bytearray(code.co_code)
        # starts with LOAD_FAST argument 0 (i.e. we're doing something with our arg)
        assert opcodes[0] == dis.opmap["LOAD_FAST"] and opcodes[1] == 0
        # ends with RETURN_VALUE
        assert opcodes[-2] == dis.opmap["RETURN_VALUE"] and opcodes[-1] == 0
        if bytecode:        # if this isn't our first function, our variable is already on the stack
            opcodes = opcodes[2:]
        # adjust LOAD_CONSTANT opcodes. each function can have constants,
        # but their indexes start at 0 in each function.  since we're
        # putting these into a single function we need to accumulate the
        # constants used in each function and adjust the indexes used in
        # the function's bytecode to access the value by its index in the
        # accumulated list.
        offset = 0
        if bytecode:
            while True:
                none = code.co_consts[0] is None
                offset = opcodes.find(dis.opmap["LOAD_CONST"], offset)
                if offset < 0:
                    break
                if not offset % 2 and (not none or opcodes[offset+1]):
                    opcodes[offset+1] += len(constants) - none
                offset += 2
            # first constant is always None. don't include multiple copies
            # (to be safe, we actually check that)
            constants.extend(code.co_consts[none:])
        else:
            assert code.co_consts[0] is None
            constants.extend(code.co_consts)
        # add our adjusted bytecode, cutting off the RETURN_VALUE opcode
        bytecode.extend(opcodes[:-2])
    bytecode.extend([dis.opmap["RETURN_VALUE"], 0])

    # note: the CodeType constructor takes version-specific positional
    # arguments, so this call may need adjusting on other CPython versions
    func = type(ops[0])(type(code)(1, 1, 0, 1, stacksize, inspect.CO_OPTIMIZED, bytes(bytecode),
                tuple(constants), (), ("x",), "<generated>", "<generated>", 0, b''),
                globals())

    return func

f = synthF(ops)
assert f(3) == 32400

Gross, with lots of caveats (called out in the comments), but it works, and it should be about as fast as your expression, since it compiles to virtually the same bytecode. It would need a bit of work to support concatenating more complex functions.


1 Comment

Thank you so much for spending the time! Will definitely check it out.

Here's an alternative using chaining. This way, there are only function calls in your generated function, no iteration.

def makeF(ops):
    f = ops[0]
    for op in ops[1:]:
        f = lambda x, op=op, f=f: op(f(x))
    return f

Bad news: it replaces each function call with two, so it's actually slower than your iterative version. :/
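For completeness, the same chaining can be written compactly with functools.reduce (a sketch; it carries the identical two-calls-per-op overhead):

```python
from functools import reduce

sq = lambda x: x**2
add = lambda x: x + 3
mul = lambda x: x * 5
ops = [add, sq, mul, sq]

def makeF(ops):
    # fold ops into one nested closure: op_n(...op_1(op_0(x)))
    # default args bind op and f at each step, as in the loop version
    return reduce(lambda f, op: lambda x, op=op, f=f: op(f(x)), ops[1:], ops[0])

assert makeF(ops)(3) == 32400
```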

1 Comment

Can confirm. It's very slightly slower than my iterative version (1.02x slower). But thanks for suggesting it nonetheless

As there seems to be no solution, this is what I've settled on. Knowing that most pipelines will be short (at most 4 operations), I just hard-code those cases to get rid of the for loop.

def generateF(ops):
    l = len(ops)
    if l == 1:
        return ops[0]
    if l == 2:
        a, b = ops
        return lambda x: b(a(x))
    if l == 3:
        a, b, c = ops
        return lambda x: c(b(a(x)))
    if l == 4:
        a, b, c, d = ops
        return lambda x: d(c(b(a(x))))
    def inner(x):
        for op in ops:
            x = op(x)
        return x
    return inner
fastF = generateF(ops)

This is only 1.4x slower than the hand-written fastF (originally 1.7-2x slower). If you have any other ideas, I will consider them.
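The hard-coded cases can be generalized to any pipeline length by generating the nested-call source once and compiling it with eval (a sketch, not from the original answer; the _opN names are invented here). Note this removes the loop overhead but still makes one call per op:

```python
sq = lambda x: x**2
add = lambda x: x + 3
mul = lambda x: x * 5
ops = [add, sq, mul, sq]

def generateF(ops):
    # build "lambda x: _op3(_op2(_op1(_op0(x))))" for any number of ops
    names = [f"_op{i}" for i in range(len(ops))]
    expr = "x"
    for name in names:
        expr = f"{name}({expr})"
    # bind each _opN name in the eval globals and compile once
    return eval(f"lambda x: {expr}", dict(zip(names, ops)))

f = generateF(ops)
assert f(3) == 32400
```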
