Supported Operations

diffstruc supports a wide range of mathematical operations, all with automatic differentiation capabilities.

Below is a table of the supported operations, along with notes on how to build your own custom operations if needed.

If there are operations you need that are not listed here and would be useful to others, please consider contributing them via a pull request on the diffstruc GitHub repository.

Operation Summary Table

Category

Operations

Notes

Arithmetic

+, -, *, /, **

Supports scalars and arrays

Trigonometric

sin, cos, tan

Input in radians, same as Fortran intrinsic

Hyperbolic

tanh

Common in neural networks

Exponential

exp, log

log is natural logarithm

Linear Algebra

matmul, outer_product, transpose

Matrix operations

Reduction

sum, mean, unspread, max, maxval

Collapse dimensions

Broadcast

spread, concat, ltrim, rtrim, index, reverse_index

Broadcasting and indexing

Other

sign, sqrt, sigmoid, gaussian

Element-wise operations

Custom Operations

If you need an operation not provided, you can implement it providing a custom Fortran function and defining the partial derivative procedures. The custom function can take any form you need. However, the partial derivative functions must conform to the interface expected by diffstruc.

The interface for the get_partial_left and get_partial_right functions are:

interface
   module function get_partial(this, upstream_grad) result(output)
     class(array_type), intent(inout) :: this
     type(array_type), intent(in) :: upstream_grad
     type(array_type) :: output
   end function get_partial
end interface

Depending on the operation, you might only need to define one of these (priority is given to the left operand if only one is defined).

A simple example of a custom operation that takes in one operand and computes the cosine is shown below. Focus on the parts marked with comments.

module custom_operations
  use diffstruc
  implicit none

  interface operation_name
    module procedure my_custom_op
  end interface

contains

  function my_custom_op(a) result(c)
    !! This is a custom operation example, it can take any form you need.
    implicit none
    class(array_type), intent(in), target :: a
    type(array_type), pointer :: c

    !! Allocates result array to the same shape as input
    c => a%create_result()
    ! !! An alternative is to provide it with a different shape, do not forget the batch_size final dimension
    ! c => a%create_result([desired_shape])

    !!-----------------------------------------------
    !! YOUR CUSTOM OPERATION
    !!-----------------------------------------------
    c%val = cos(a%val)
    !!-----------------------------------------------

    c%get_partial_left => get_partial_left_custom_op
    if(a%requires_grad) then
      c%requires_grad = .true.
      c%is_forward = a%is_forward
      c%operation = 'cos'
      c%left_operand => a
      c%owns_left_operand = a%is_temporary
    end if
  end function my_custom_op

  function get_partial_left_custom_op(this, upstream_grad) result(output)
    !! Partial derivative of custom operation w.r.t. left operand
    !! This has to conform to the interface expected by diffstruc
    implicit none
    class(array_type), intent(inout) :: this
    type(array_type), intent(in) :: upstream_grad
    type(array_type) :: output

    logical :: left_is_temporary_local
    type(array_type), pointer :: ptr

    !! Save and temporarily disable the temporary status of the left operand
    left_is_temporary_local = this%left_operand%is_temporary
    this%left_operand%is_temporary = .false.

    !!-----------------------------------------------
    !! YOUR CUSTOM PARTIAL DERIVATIVE
    !!-----------------------------------------------
    ptr => -upstream_grad * sin( this%left_operand )
    !!-----------------------------------------------

    !! Restore the temporary status of the left operand
    this%left_operand%is_temporary = left_is_temporary_local

    call output%assign_and_deallocate_source(ptr)
  end function get_partial_left_custom_op

end module custom_operations

For one with two operands, you would similarly define get_partial_right_custom_op and associate the right_operand pointer of the result. For how this works, see the built-in matmul operation in the source code (diffstruc_operations_linalg_sub.f90)