I've been using py4j to build a user-friendly Python library around a less user-friendly Java library. For the most part, this has been a breeze, and py4j has been a great tool. However, I've come across a snag when sending matrices between Python and Java.

Specifically, I have a static method in Java that accepts an integer matrix as its argument:

public class MyClass {
   // ...
   public static MyObject create(int[][] matrix) {
      // ...
   }
}

I'd like to be able to call this from Py4j like so:

def create_java_object(numpy_matrix):
   # <code here checks that numpy_matrix is a (3 x n) integer matrix>
   # ...
   return java_instance.jvm.my.namespace.MyClass.create(numpy_matrix)

This doesn't work, which isn't too surprising. It also fails if numpy_matrix is first converted to a list of plain Python lists. I had expected that the solution would be to construct a Java array and copy the data over before the function call:

def create_java_object(numpy_matrix):
   # <code here checks that numpy_matrix is a (3 x n) integer matrix>
   # ...
   n = numpy_matrix.shape[1]
   java_matrix = java_instance.new_array(java_instance.jvm.int, 3, n)
   for i in range(n):
      java_matrix[0][i] = int(numpy_matrix[0, i])
      java_matrix[1][i] = int(numpy_matrix[1, i])
      java_matrix[2][i] = int(numpy_matrix[2, i])
   return java_instance.jvm.my.namespace.MyClass.create(java_matrix)

Now, this code runs correctly. However, it requires approximately two minutes to run. The matrices I'm working with, by the way, are on the order of (3 x ~300,000) elements.

Is there a canonical way to do this in Py4j that doesn't take an enormous amount of time just to convert a matrix? I don't mind it taking a second or two, but two minutes is far too slow. If Py4j isn't set up for this kind of communication, is there a Java interop library for Python that is?

Note: The Java library treats the int[][] matrix as an immutable array; i.e., it never attempts to modify it.

1 Answer

I found a solution that works for this particular case, though it is not terribly elegant:

Py4j supports efficiently passing a Python bytearray object to Java as a byte[] array. I worked around the problem by modifying the original library and my Python code.

The new Java code:

public class MyClass {
   // ...
   public static MyObject create(int[][] matrix) {
      // ...
   }
   public static MyObject createFromPy4j(byte[] data) {
      java.nio.ByteBuffer buf = java.nio.ByteBuffer.wrap(data);
      int n = buf.getInt(), m = buf.getInt();
      int[][] matrix = new int[n][m];
      for (int i = 0; i < n; ++i)
         for (int j = 0; j < m; ++j)
            matrix[i][j] = buf.getInt();
      return MyClass.create(matrix);
   }
}
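
To sanity-check the wire format without starting a JVM, the same decoding can be mimicked in Python. This is only a sketch: decode_matrix is a hypothetical helper, not part of Py4j, and it assumes the layout read by createFromPy4j above (two big-endian 32-bit ints for the shape, followed by the values in row-major order).

```python
import struct

def decode_matrix(data):
    """Mirror of the Java reader: big-endian int32 header (n, m), then n*m values."""
    n, m = struct.unpack_from('>ii', data, 0)
    values = struct.unpack_from('>%di' % (n * m), data, 8)
    # Rebuild the rows in row-major order, like the nested Java loops.
    return [list(values[i * m:(i + 1) * m]) for i in range(n)]

# Hand-built payload for a 2 x 3 matrix (hypothetical test data).
payload = struct.pack('>ii', 2, 3) + struct.pack('>6i', 1, 2, 3, 4, 5, 6)
print(decode_matrix(payload))  # [[1, 2, 3], [4, 5, 6]]
```

Feeding the output of the Python packing code into a helper like this is a quick way to catch byte-order mistakes before involving Java.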

The new Python code:

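import array
import sys

def create_java_object(numpy_matrix):
   # Header: the two dimensions; body: the values in row-major order.
   header = array.array('i', list(numpy_matrix.shape))
   body = array.array('i', numpy_matrix.flatten().tolist())
   # Java's ByteBuffer reads big-endian by default, so swap on little-endian hosts.
   if sys.byteorder != 'big':
      header.byteswap()
      body.byteswap()
   # tobytes() replaces tostring(), which was removed in Python 3.9.
   buf = bytearray(header.tobytes() + body.tobytes())
   return java_instance.jvm.my.namespace.MyClass.createFromPy4j(buf)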

This runs in a few seconds rather than a few minutes.
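
If NumPy is already holding the data, the packing step can probably stay in NumPy as well, which avoids building an intermediate Python list. The following is a minimal sketch rather than a drop-in replacement: matrix_to_bytes is a hypothetical name, and it assumes the values fit in 32 bits (astype wraps silently otherwise) and the same header layout that createFromPy4j reads.

```python
import numpy as np

def matrix_to_bytes(numpy_matrix):
    """Pack a 2-D integer matrix as big-endian int32: shape header, then row-major data."""
    mat = np.asarray(numpy_matrix)
    if mat.ndim != 2:
        raise ValueError("expected a 2-D matrix")
    header = np.array(mat.shape, dtype='>i4')
    # '>i4' is 32-bit big-endian, which java.nio.ByteBuffer reads by default;
    # tobytes() emits the elements in C (row-major) order.
    body = mat.astype('>i4')
    return bytearray(header.tobytes() + body.tobytes())

buf = matrix_to_bytes(np.array([[1, 2, 3], [4, 5, 6]]))
print(len(buf))  # 32: two header ints plus six values, 4 bytes each
```

The resulting bytearray would be passed to createFromPy4j exactly as in the function above.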

1 Comment

This would only work for 2D matrices? How about storing the shape length as the first value so that you can send a matrix of any shape?
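
The suggestion in this comment could be sketched roughly as follows. The layout here is an assumption, not something the answer's Java code understands: the number of dimensions comes first, then each dimension, then the values, all as big-endian 32-bit ints. pack_nd and unpack_nd are hypothetical names, and the Java reader would need a matching change that is not shown.

```python
import struct

def pack_nd(shape, flat_values):
    """Layout (all big-endian int32): ndim, then each dimension, then the values."""
    count = 1
    for dim in shape:
        count *= dim
    if count != len(flat_values):
        raise ValueError("shape does not match number of values")
    header = struct.pack('>i%di' % len(shape), len(shape), *shape)
    body = struct.pack('>%di' % count, *flat_values)
    return bytearray(header + body)

def unpack_nd(data):
    """Inverse of pack_nd; returns (shape, flat_values)."""
    (ndim,) = struct.unpack_from('>i', data, 0)
    shape = struct.unpack_from('>%di' % ndim, data, 4)
    count = 1
    for dim in shape:
        count *= dim
    values = struct.unpack_from('>%di' % count, data, 4 + 4 * ndim)
    return shape, list(values)

buf = pack_nd((2, 1, 3), [1, 2, 3, 4, 5, 6])
print(unpack_nd(buf))  # ((2, 1, 3), [1, 2, 3, 4, 5, 6])
```

For a NumPy array a, the arguments would be a.shape and a.ravel().tolist().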
