diff --git a/doc/specs/stdlib_intrinsics.md b/doc/specs/stdlib_intrinsics.md index be23adeb5..971a55ecd 100644 --- a/doc/specs/stdlib_intrinsics.md +++ b/doc/specs/stdlib_intrinsics.md @@ -155,4 +155,42 @@ The output is a scalar of the same type and kind as to that of `x` and `y`. ```fortran {!example/intrinsics/example_dot_product.f90!} -``` \ No newline at end of file +``` + +### `stdlib_matmul` function + +#### Description + +The extension of the intrinsic function `matmul` to handle more than 2 and less than or equal to 5 matrices, with error handling using `linalg_state_type`. +The optimal parenthesization to minimize the number of scalar multiplications is done using the Algorithm as outlined in Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2. +The actual matrix multiplication is performed using the `gemm` interfaces. +It supports only `real` and `complex` matrices. + +#### Syntax + +`res = ` [[stdlib_intrinsics(module):stdlib_matmul(interface)]] ` (m1, m2, m3, m4, m5, err)` + +#### Status + +Experimental + +#### Class + +Function. + +#### Argument(s) + +`m1`, `m2`: 2D arrays of the same kind and type. `intent(in)` arguments. +`m3`,`m4`,`m5`: 2D arrays of the same kind and type as the other matrices. `intent(in), optional` arguments. +`err`: `type(linalg_state_type), intent(out), optional` argument. Can be used for elegant error handling. It is assigned `LINALG_VALUE_ERROR` + in case the matrices are not of compatible sizes. + +#### Result + +The output is a matrix of the appropriate size. + +#### Example + +```fortran +{!example/intrinsics/example_matmul.f90!} +``` diff --git a/example/intrinsics/CMakeLists.txt b/example/intrinsics/CMakeLists.txt index 1645ba8a1..162744b66 100644 --- a/example/intrinsics/CMakeLists.txt +++ b/example/intrinsics/CMakeLists.txt @@ -1,2 +1,3 @@ ADD_EXAMPLE(sum) -ADD_EXAMPLE(dot_product) \ No newline at end of file +ADD_EXAMPLE(dot_product) +ADD_EXAMPLE(matmul) diff --git a/example/intrinsics/example_matmul.f90 b/example/intrinsics/example_matmul.f90 new file mode 100644 index 000000000..b62215074 --- /dev/null +++ b/example/intrinsics/example_matmul.f90 @@ -0,0 +1,12 @@ +program example_matmul + use stdlib_intrinsics, only: stdlib_matmul + complex :: x(2, 2), y(2, 2), z(2, 2) + x = reshape([(0, 0), (1, 0), (1, 0), (0, 0)], [2, 2]) + y = reshape([(0, 0), (0, 1), (0, -1), (0, 0)], [2, 2]) ! pauli y-matrix + z = reshape([(1, 0), (0, 0), (0, 0), (-1, 0)], [2, 2]) + + print *, stdlib_matmul(x, y) ! should be iota*z + print *, stdlib_matmul(y, z, x) ! should be iota*identity + print *, stdlib_matmul(x, x, z, y) ! should be -iota*x + print *, stdlib_matmul(x, x, z, y, y) ! should be z +end program example_matmul diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c3cd99120..5e915dfed 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,6 +19,7 @@ set(fppFiles stdlib_hash_64bit_spookyv2.fypp stdlib_intrinsics_dot_product.fypp stdlib_intrinsics_sum.fypp + stdlib_intrinsics_matmul.fypp stdlib_intrinsics.fypp stdlib_io.fypp stdlib_io_npy.fypp @@ -32,14 +33,14 @@ set(fppFiles stdlib_linalg_kronecker.fypp stdlib_linalg_cross_product.fypp stdlib_linalg_eigenvalues.fypp - stdlib_linalg_solve.fypp + stdlib_linalg_solve.fypp stdlib_linalg_determinant.fypp stdlib_linalg_qr.fypp stdlib_linalg_inverse.fypp stdlib_linalg_pinv.fypp stdlib_linalg_norms.fypp stdlib_linalg_state.fypp - stdlib_linalg_svd.fypp + stdlib_linalg_svd.fypp stdlib_linalg_cholesky.fypp stdlib_linalg_schur.fypp stdlib_optval.fypp diff --git a/src/stdlib_intrinsics.fypp b/src/stdlib_intrinsics.fypp index b2c16a5a6..4b5751df6 100644 --- a/src/stdlib_intrinsics.fypp +++ b/src/stdlib_intrinsics.fypp @@ -8,6 +8,7 @@ module stdlib_intrinsics !!Alternative implementations of some Fortran intrinsic functions offering either faster and/or more accurate evaluation. !! ([Specification](../page/specs/stdlib_intrinsics.html)) use stdlib_kinds + use stdlib_linalg_state, only: linalg_state_type implicit none private @@ -146,6 +147,49 @@ module stdlib_intrinsics #:endfor end interface public :: kahan_kernel + + interface stdlib_matmul + !! version: experimental + !! + !!### Summary + !! compute the matrix multiplication of more than two matrices with a single function call. + !! ([Specification](../page/specs/stdlib_intrinsics.html#stdlib_matmul)) + !! + !!### Description + !! + !! matrix multiply more than two matrices with a single function call + !! the multiplication with the optimal parenthesization for efficiency of computation is done automatically + !! Supported data types are `real` and `complex`. + !! + !! Note: The matrices must be of compatible shapes to be multiplied + #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES + pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + ${t}$, allocatable :: r(:,:) + end function stdlib_matmul_pure_${s}$ + + module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + type(linalg_state_type), intent(out) :: err + ${t}$, allocatable :: r(:,:) + end function stdlib_matmul_${s}$ + #:endfor + end interface stdlib_matmul + public :: stdlib_matmul + + ! internal interface + interface stdlib_matmul_sub + #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES + pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err) + ${t}$, intent(out), allocatable :: res(:,:) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + type(linalg_state_type), intent(out), optional :: err + end subroutine stdlib_matmul_sub_${s}$ + #:endfor + end interface stdlib_matmul_sub contains diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp new file mode 100644 index 000000000..c24f95374 --- /dev/null +++ b/src/stdlib_intrinsics_matmul.fypp @@ -0,0 +1,288 @@ +#:include "common.fypp" +#:set I_KINDS_TYPES = list(zip(INT_KINDS, INT_TYPES, INT_KINDS)) +#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX)) +#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX)) + +submodule (stdlib_intrinsics) stdlib_intrinsics_matmul + use stdlib_linalg_blas, only: gemm + use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR, LINALG_INTERNAL_ERROR + use stdlib_constants + implicit none + + character(len=*), parameter :: this = "stdlib_matmul" + +contains + + ! Algorithm for the optimal parenthesization of matrices + ! Reference: Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2 + ! Internal use only! + pure function matmul_chain_order(p) result(s) + integer, intent(in) :: p(:) + integer :: s(1:size(p) - 2, 2:size(p) - 1), m(1:size(p) - 1, 1:size(p) - 1) + integer :: n, l, i, j, k, q + n = size(p) - 1 + m(:,:) = 0 + s(:,:) = 0 + + do l = 2, n + do i = 1, n - l + 1 + j = i + l - 1 + m(i,j) = huge(1) + + do k = i, j - 1 + q = m(i,k) + m(k+1,j) + p(i)*p(k+1)*p(j+1) + + if (q < m(i, j)) then + m(i,j) = q + s(i,j) = k + end if + end do + end do + end do + end function matmul_chain_order + +#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES + + pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s, p) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:) + integer, intent(in) :: start, s(:,2:), p(:) + ${t}$, allocatable :: r(:,:), temp(:,:) + integer :: ord, m, n, k + ord = s(start, start + 2) + allocate(r(p(start), p(start + 3))) + + if (ord == start) then + ! m1*(m2*m3) + m = p(start + 1) + n = p(start + 3) + k = p(start + 2) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m2, m, m3, k, zero_${s}$, temp, m) + m = p(start) + n = p(start + 3) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) + else if (ord == start + 1) then + ! (m1*m2)*m3 + m = p(start) + n = p(start + 2) + k = p(start + 1) + allocate(temp(m, n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + m = p(start) + n = p(start + 3) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m) + else + ! our internal functions are incorrent, abort + error stop this//": error: unexpected s(i,j)" + end if + + end function matmul_chain_mult_${s}$_3 + + pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s, p) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:) + integer, intent(in) :: start, s(:,2:), p(:) + ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:) + integer :: ord, m, n, k + ord = s(start, start + 3) + allocate(r(p(start), p(start + 4))) + + if (ord == start) then + ! m1*(m2*m3*m4) + temp = matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s, p) + m = p(start) + n = p(start + 4) + k = p(start + 1) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m) + else if (ord == start + 1) then + ! (m1*m2)*(m3*m4) + m = p(start) + n = p(start + 2) + k = p(start + 1) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + + m = p(start + 2) + n = p(start + 4) + k = p(start + 3) + allocate(temp1(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m3, m, m4, k, zero_${s}$, temp1, m) + + m = p(start) + n = p(start + 4) + k = p(start + 2) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m) + else if (ord == start + 2) then + ! (m1*m2*m3)*m4 + temp = matmul_chain_mult_${s}$_3(m1, m2, m3, start, s, p) + m = p(start) + n = p(start + 4) + k = p(start + 3) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m) + else + ! our internal functions are incorrent, abort + error stop this//": error: unexpected s(i,j)" + end if + + end function matmul_chain_mult_${s}$_4 + + pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err) + ${t}$, intent(out), allocatable :: res(:,:) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + type(linalg_state_type), intent(out), optional :: err + ${t}$, allocatable :: temp(:,:), temp1(:,:) + integer :: p(6), num_present, m, n, k + integer, allocatable :: s(:,:) + + type(linalg_state_type) :: err0 + + p(1) = size(m1, 1) + p(2) = size(m2, 1) + p(3) = size(m2, 2) + + if (size(m1, 2) /= p(2)) then + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m1=',shape(m1),& + ', m2=',shape(m2),'have incompatible sizes') + call linalg_error_handling(err0, err) + allocate(res(0, 0)) + return + end if + + num_present = 2 + if (present(m3)) then + + if (size(m3, 1) /= p(3)) then + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m2=',shape(m2), & + ', m3=',shape(m3),'have incompatible sizes') + call linalg_error_handling(err0, err) + allocate(res(0, 0)) + return + end if + + p(3) = size(m3, 1) + p(4) = size(m3, 2) + num_present = num_present + 1 + end if + if (present(m4)) then + + if (size(m4, 1) /= p(4)) then + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m3=',shape(m3), & + ', m4=',shape(m4),' have incompatible sizes') + call linalg_error_handling(err0, err) + allocate(res(0, 0)) + return + end if + + p(4) = size(m4, 1) + p(5) = size(m4, 2) + num_present = num_present + 1 + end if + if (present(m5)) then + + if (size(m5, 1) /= p(5)) then + err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m4=',shape(m4), & + ', m5=',shape(m5),' have incompatible sizes') + call linalg_error_handling(err0, err) + allocate(res(0, 0)) + return + end if + + p(5) = size(m5, 1) + p(6) = size(m5, 2) + num_present = num_present + 1 + end if + + allocate(res(p(1), p(num_present + 1))) + + if (num_present == 2) then + m = p(1) + n = p(3) + k = p(2) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, res, m) + return + end if + + ! Now num_present >= 3 + allocate(s(1:num_present - 1, 2:num_present)) + + s = matmul_chain_order(p(1: num_present + 1)) + + if (num_present == 3) then + res = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4)) + return + else if (num_present == 4) then + res = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5)) + return + end if + + ! Now num_present is 5 + + select case (s(1, 5)) + case (1) + ! m1*(m2*m3*m4*m5) + temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p) + m = p(1) + n = p(6) + k = p(2) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, res, m) + case (2) + ! (m1*m2)*(m3*m4*m5) + m = p(1) + n = p(3) + k = p(2) + allocate(temp(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m) + + temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p) + + k = n + n = p(6) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res, m) + case (3) + ! (m1*m2*m3)*(m4*m5) + temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3, s, p) + + m = p(4) + n = p(6) + k = p(5) + allocate(temp1(m,n)) + call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m) + + k = m + m = p(1) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res, m) + case (4) + ! (m1*m2*m3*m4)*m5 + temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p) + m = p(1) + n = p(6) + k = p(5) + call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, res, m) + case default + err0 = linalg_state_type(this,LINALG_INTERNAL_ERROR,"internal error: unexpected s(i,j)") + call linalg_error_handling(err0,err) + end select + + end subroutine stdlib_matmul_sub_${s}$ + + pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + ${t}$, allocatable :: r(:,:) + + call stdlib_matmul_sub(r, m1, m2, m3, m4, m5) + end function stdlib_matmul_pure_${s}$ + + module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r) + ${t}$, intent(in) :: m1(:,:), m2(:,:) + ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:) + type(linalg_state_type), intent(out) :: err + ${t}$, allocatable :: r(:,:) + + call stdlib_matmul_sub(r, m1, m2, m3, m4, m5, err=err) + end function stdlib_matmul_${s}$ + +#:endfor +end submodule stdlib_intrinsics_matmul diff --git a/test/intrinsics/test_intrinsics.fypp b/test/intrinsics/test_intrinsics.fypp index 8aefe09d3..dcbbe2e6a 100644 --- a/test/intrinsics/test_intrinsics.fypp +++ b/test/intrinsics/test_intrinsics.fypp @@ -7,6 +7,7 @@ module test_intrinsics use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64 use stdlib_intrinsics + use stdlib_linalg_state, only: linalg_state_type, LINALG_VALUE_ERROR, operator(==) use stdlib_math, only: swap implicit none @@ -19,7 +20,8 @@ subroutine collect_suite(testsuite) testsuite = [ & new_unittest('sum', test_sum), & - new_unittest('dot_product', test_dot_product) & + new_unittest('dot_product', test_dot_product), & + new_unittest('matmul', test_matmul) & ] end subroutine @@ -249,6 +251,45 @@ subroutine test_dot_product(error) #:endfor end subroutine + +subroutine test_matmul(error) + type(error_type), allocatable, intent(out) :: error + type(linalg_state_type) :: linerr + real :: a(2, 3), b(3, 4), c(3, 2), d(2, 2) + + d = stdlib_matmul(a, b, c, err=linerr) + call check(error, linerr == LINALG_VALUE_ERROR, "incompatible matrices are considered compatible") + if (allocated(error)) return + + #:for k, t, s in R_KINDS_TYPES + block + ${t}$ :: x(10,20), y(20,30), z(30,10), r(10,10), r1(10,10) + call random_number(x) + call random_number(y) + call random_number(z) + + r = stdlib_matmul(x, y, z) ! the optimal ordering would be (x(yz)) + r1 = matmul(matmul(x, y), z) ! the opposite order to induce a difference + + call check(error, all(abs(r-r1) <= epsilon(0._${k}$) * 300), "real, ${k}$, 3 args: error too large") + if (allocated(error)) return + end block + + block + ${t}$ :: x(10,20), y(20,30), z(30,10), w(10, 20), r(10,20), r1(10,20) + call random_number(x) + call random_number(y) + call random_number(z) + call random_number(w) + + r = stdlib_matmul(x, y, z, w) ! the optimal order would be ((x(yz))w) + r1 = matmul(matmul(x, y), matmul(z, w)) + + call check(error, all(abs(r-r1) <= epsilon(0._${k}$) * 1500), "real, ${k}$, 4 args: error too large") + if (allocated(error)) return + end block + #:endfor +end subroutine test_matmul end module test_intrinsics @@ -276,4 +317,4 @@ program tester write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!" error stop end if -end program \ No newline at end of file +end program