getfem-commits
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[Getfem-commits] (no subject)


From: Konstantinos Poulios
Subject: [Getfem-commits] (no subject)
Date: Fri, 20 Oct 2023 05:00:27 -0400 (EDT)

branch: improve-expm-performance
commit ee139e697b50db9198ef8257e2c492202b037a35
Author: Konstantinos Poulios <logari81@gmail.com>
AuthorDate: Fri Oct 20 10:59:42 2023 +0200

    Use a Pade approximation for expm, ported from Eigen/Unsupported
---
 src/getfem_plasticity.cc | 472 ++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 386 insertions(+), 86 deletions(-)

diff --git a/src/getfem_plasticity.cc b/src/getfem_plasticity.cc
index 5a0ec9c0..69ca66b2 100644
--- a/src/getfem_plasticity.cc
+++ b/src/getfem_plasticity.cc
@@ -46,109 +46,409 @@ namespace getfem {
     { mi.resize(2); mi[0] = mi[1] = N; }
 
 
-  bool expm(const base_matrix &a_, base_matrix &aexp, scalar_type tol=1e-15) {
+  inline void matmul(base_matrix &aa,base_matrix &bb,base_matrix &cc)
+    {gmm::mult(aa,bb,cc);}
 
-    const size_type itmax = 40;
-    base_matrix a(a_);
-    // scale input matrix a
-    int e;
-    frexp(gmm::mat_norminf(a), &e);
-    e = std::max(0, std::min(1023, e));
-    gmm::scale(a, pow(scalar_type(2),-scalar_type(e)));
-
-    base_matrix atmp(a), an(a);
-    gmm::copy(a, aexp);
-    gmm::add(gmm::identity_matrix(), aexp);
-    scalar_type factn(1);
+  bool expm(const base_matrix &a_, base_matrix &aexp) {
+
+    const size_type N = gmm::mat_nrows(a_);
     bool success(false);
-    for (size_type n=2; n < itmax; ++n) {
-      factn /= scalar_type(n);
-      gmm::mult(an, a, atmp);
-      gmm::copy(atmp, an);
-      gmm::scale(atmp, factn);
-      gmm::add(atmp, aexp);
-      if (gmm::mat_euclidean_norm(atmp) < tol) {
-        success = true;
-        break;
-      }
+
+    // Pade approximation ported from Eigen/Unsupported
+    base_matrix a(a_);
+    gmm::clear(aexp.as_vector());
+    base_matrix tmp(aexp), v(aexp), u(aexp); // Pade approximant is (v+u)/(v-u)
+    const scalar_type l1norm = gmm::mat_norminf(a_);
+    int e = 0; // squarings
+    if (l1norm < 1.495585217958292e-002) { // matrix_exp_pade3(a, u, v)
+      const static std::array<scalar_type,4> b{120,60,12,1};
+      base_matrix a2(a);
+      matmul(a, a, a2);
+      gmm::copy(gmm::scaled(a2,b[2]), v);   // v = b2*A2 + b0*I
+      gmm::copy(gmm::scaled(a2,b[3]), u);   // u = b3*A2 + b1*I
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+    } else if (l1norm < 2.539398330063230e-001) { // matrix_exp_pade5(a, u, v)
+      const static std::array<scalar_type,6> b{30240,15120,3360,420,30,1};
+      base_matrix a2(a), a4(a);
+      matmul(a, a, a2);
+      matmul(a2, a2, a4);
+      gmm::add(gmm::scaled(a4,b[4]),    // v = b4*A4 + b2*A2 + b0*I
+               gmm::scaled(a2,b[2]), v);
+      gmm::add(gmm::scaled(a4,b[5]),    // u = b5*A4 + b3*A2 + b1*I
+               gmm::scaled(a2,b[3]), u);
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+    } else if (l1norm < 9.504178996162932e-001) { // matrix_exp_pade7(a, u, v)
+      const static std::array<scalar_type,8>
+        b{17297280,8648640,1995840,277200,25200,1512,56,1};
+      base_matrix a2(a), a4(a), a6(a);
+      matmul(a, a, a2);
+      matmul(a2, a2, a4);
+      matmul(a2, a4, a6);
+      gmm::add(gmm::scaled(a6,b[6]),    // v = b6*A6 + b4*A4 + b2*A2 + b0*I
+               gmm::scaled(a4,b[4]), v);
+      gmm::add(gmm::scaled(a2,b[2]), v);
+      gmm::add(gmm::scaled(a6,b[7]),    // u = b7*A6 + b5*A4 + b3*A2 + b1*I
+               gmm::scaled(a4,b[5]), u);
+      gmm::add(gmm::scaled(a2,b[3]), u);
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+    } else if (l1norm < 2.097847961257068e+000) { // matrix_exp_pade9(a, u, v)
+      const static std::array<scalar_type,10>
+        b{17643225600,8821612800,2075673600,302702400,30270240,2162160,
+          110880,3960,90,1};
+      base_matrix a2(a), a4(a), a6(a), a8(a);
+      matmul(a, a, a2);
+      matmul(a2, a2, a4);
+      matmul(a2, a4, a6);
+      matmul(a4, a4, a8);
+      gmm::add(gmm::scaled(a8,b[8]),    // v = b8*A8+b6*A6+b4*A4+b2*A2+b0*I
+               gmm::scaled(a6,b[6]), v);
+      gmm::add(gmm::scaled(a4,b[4]), v);
+      gmm::add(gmm::scaled(a2,b[2]), v);
+      gmm::add(gmm::scaled(a8,b[9]),    // u = b9*A8+b7*A6+b5*A4+b3*A2+b1*I
+               gmm::scaled(a6,b[7]), u);
+      gmm::add(gmm::scaled(a4,b[5]), u);
+      gmm::add(gmm::scaled(a2,b[3]), u);
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+    } else { // matrix_exp_pade13(a, u, v)
+      const scalar_type maxnorm = 5.371920351148152;
+      frexp(l1norm / maxnorm, &e);
+      if (e <= 0) e = 0;
+      else for (auto &&val : a.as_vector()) { val = ldexp(val,-e); }
+           // <==> gmm::scale(a, pow(scalar_type(2),-scalar_type(e)));
+      const static std::array<scalar_type,14>
+        b{64764752532480000,32382376266240000,7771770303897600,
+          1187353796428800,129060195264000,10559470521600,670442572800,
+          33522128640,1323241920,40840800,960960,16380,182,1};
+      base_matrix a2(a), a4(a), a6(a);
+      matmul(a, a, a2);
+      matmul(a2, a2, a4);
+      matmul(a2, a4, a6);
+      gmm::add(gmm::scaled(a6,b[12]),
+               gmm::scaled(a4,b[10]), tmp);
+      gmm::add(gmm::scaled(a2,b[8]), tmp);
+      matmul(a6, tmp, v);             // v = b12*A12+b10*A10+b8*A8
+      gmm::add(gmm::scaled(a6,b[6]), v); //   + b6*A6+b4*A4+b2*A2+b0*I
+      gmm::add(gmm::scaled(a4,b[4]), v);
+      gmm::add(gmm::scaled(a2,b[2]), v);
+      gmm::add(gmm::scaled(a6,b[13]),
+               gmm::scaled(a4,b[11]), tmp);
+      gmm::add(gmm::scaled(a2,b[9]), tmp);
+      matmul(a6, tmp, u);             // u = b13*A12+b11*A10+b9*A8
+      gmm::add(gmm::scaled(a6,b[7]), u); //   + b7*A6+b5*A4+b3*A2+b1*I
+      gmm::add(gmm::scaled(a4,b[5]), u);
+      gmm::add(gmm::scaled(a2,b[3]), u);
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
     }
-    // unscale result
-    for (int i=0; i < e; ++i) {
-      gmm::mult(aexp, aexp, atmp);
-      gmm::copy(atmp, aexp);
+    std::swap(u, tmp);
+    matmul(a, tmp, u);              // u <-- A*u
+
+    gmm::add(v,gmm::scaled(u,-1),tmp); // tmp = denom = v-u
+    gmm::lu_inverse(tmp);              // tmp = (v-u)^-1
+    gmm::add(u,v);                     // v <-- numer = v+u;
+    matmul(tmp,v,aexp);
+    success = true;
+
+    for (int i=0; i < e; ++i) { // unscale result
+      std::swap(aexp, tmp);
+      matmul(tmp, tmp, aexp);
     }
     return success;
   }
 
-  bool expm_deriv(const base_matrix &a_, base_tensor &daexp,
-                  base_matrix *paexp=NULL, scalar_type tol=1e-15) {
-
-    const size_type itmax = 40;
-    size_type N = gmm::mat_nrows(a_);
-    size_type N2 = N*N;
 
-    base_matrix a(a_);
-    // scale input matrix a
-    int e;
-    frexp(gmm::mat_norminf(a), &e);
-    e = std::max(0, std::min(1023, e));
-    scalar_type scale = pow(scalar_type(2),-scalar_type(e));
-    gmm::scale(a, scale);
-
-    base_vector factnn(itmax);
-    base_matrix atmp(a), an(a), aexp(a);
-    base_tensor ann(bgeot::multi_index(N,N,itmax));
-    gmm::add(gmm::identity_matrix(), aexp);
-    gmm::copy(gmm::identity_matrix(), atmp);
-    std::copy(atmp.begin(), atmp.end(), ann.begin());
-    factnn[1] = 1;
-    std::copy(a.begin(), a.end(), ann.begin()+N2);
-    size_type n;
-    bool success(false);
-    for (n=2; n < itmax; ++n) {
-      factnn[n] = factnn[n-1]/scalar_type(n);
-      gmm::mult(an, a, atmp);
-      gmm::copy(atmp, an);
-      std::copy(an.begin(), an.end(), ann.begin()+n*N2);
-      gmm::scale(atmp, factnn[n]);
-      gmm::add(atmp, aexp);
-      if (gmm::mat_euclidean_norm(atmp) < tol) {
-        success = true;
-        break;
-      }
-    }
 
-    if (!success)
-      return false;
+  bool expm_deriv(const base_matrix &a_, base_tensor &daexp) {
 
+    size_type N = gmm::mat_nrows(a_);
+    base_matrix a(a_), tmp(a_);
+    gmm::clear(tmp.as_vector());
+    base_matrix aexp(tmp), v(tmp), u(tmp), // Pade approximant is (v+u)/(v-u)
+                tmp_(tmp), dv_(tmp), du_(tmp);
     gmm::clear(daexp.as_vector());
-    gmm::scale(factnn, scale);
-    for (--n; n >= 1; --n) {
-      scalar_type factn = factnn[n];
-      for (size_type m=1; m <= n; ++m)
-        for (size_type l=0; l < N; ++l)
-          for (size_type k=0; k < N; ++k)
-            for (size_type j=0; j < N; ++j)
-              for (size_type i=0; i < N; ++i)
-                daexp(i,j,k,l) += factn*ann(i,k,m-1)*ann(l,j,n-m);
-    }
+    base_tensor dv(daexp), du(daexp);
+    const scalar_type l1norm = gmm::mat_norminf(a_);
+    int e = 0; // squarings
+    if (l1norm < 1.495585217958292e-002) { // matrix_exp_pade3(a, u, v)
+      const static std::array<scalar_type,4> b{120,60,12,1};
+      base_matrix a2(a);
+      matmul(a, a, a2);
+      gmm::copy(gmm::scaled(a2,b[2]), v);   // v = b2*A2 + b0*I
+      gmm::copy(gmm::scaled(a2,b[3]), u);   // u = b3*A2 + b1*I
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+
+      for (size_type l=0; l < N; ++l) // tmp derivative of a2
+        for (size_type k=0; k < N; ++k) {
+          gmm::clear(dv_); gmm::clear(du_);
+          for (size_type ij=0; ij < N; ++ij) {
+            const auto &al=a(l,ij), &ak=a(ij,k);
+            dv_(k,ij) += b[2]*al;   dv_(ij,l) += b[2]*ak;
+            du_(k,ij) += b[3]*al;   du_(ij,l) += b[3]*ak;
+          }
+          std::swap(du_,tmp); // derivative of u <-- A*u
+          matmul(a,tmp,du_);
+          for (size_type j=0; j < N; ++j) // i == k
+            du_(k,j) += u(l,j);
 
-    // unscale result
-    base_matrix atmp1(a), atmp2(a);
-    for (int i=0; i < e; ++i) {
+          std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l));
+          std::copy(du_.begin(),du_.end(), &du(0,0,k,l));
+        }
+    } else if (l1norm < 2.539398330063230e-001) { // matrix_exp_pade5(a, u, v)
+      const static std::array<scalar_type,6> b{30240,15120,3360,420,30,1};
+      base_matrix a2(a), a4(a);
+      matmul(a, a, a2);
+      matmul(a2, a2, a4);
+      gmm::add(gmm::scaled(a4,b[4]),    // v = b4*A4 + b2*A2 + b0*I
+               gmm::scaled(a2,b[2]), v);
+      gmm::add(gmm::scaled(a4,b[5]),    // u = b5*A4 + b3*A2 + b1*I
+               gmm::scaled(a2,b[3]), u);
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+
+      base_matrix da2(aexp); // zero init
       for (size_type l=0; l < N; ++l)
         for (size_type k=0; k < N; ++k) {
-          std::copy(&daexp(0,0,k,l), &daexp(0,0,k,l)+N*N, atmp.begin());
-          gmm::mult(atmp, aexp, atmp1);
-          gmm::mult(aexp, atmp, atmp2);
-          gmm::add(atmp1, atmp2, atmp);
-          std::copy(atmp.begin(), atmp.end(), &daexp(0,0,k,l));
+          gmm::clear(da2); gmm::clear(dv_); gmm::clear(du_);
+          for (size_type ij=0; ij < N; ++ij) {
+            const auto &al=a(l,ij), &ak=a(ij,k);
+            da2(k,ij) += al;        da2(ij,l) += ak;
+            dv_(k,ij) += b[2]*al;   dv_(ij,l) += b[2]*ak;
+            du_(k,ij) += b[3]*al;   du_(ij,l) += b[3]*ak;
+          }
+          matmul(a2,da2,tmp);
+          matmul(da2,a2,tmp_);
+          gmm::add(tmp_,tmp);        // tmp derivative of a4
+          gmm::add(gmm::scaled(tmp,b[4]), dv_);
+          gmm::add(gmm::scaled(tmp,b[5]), du_);
+
+          std::swap(du_,tmp); // derivative of u <-- A*u
+          matmul(a,tmp,du_);
+          for (size_type j=0; j < N; ++j) // i == k
+            du_(k,j) += u(l,j);
+
+          std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l));
+          std::copy(du_.begin(),du_.end(), &du(0,0,k,l));
+        }
+    } else if (l1norm < 9.504178996162932e-001) { // matrix_exp_pade7(a, u, v)
+      const static std::array<scalar_type,8>
+        b{17297280,8648640,1995840,277200,25200,1512,56,1};
+      base_matrix a2(a), a4(a), a6(a);
+      matmul(a, a, a2);
+      matmul(a2, a2, a4);
+      matmul(a2, a4, a6);
+      gmm::add(gmm::scaled(a6,b[6]),    // v = b6*A6 + b4*A4 + b2*A2 + b0*I
+               gmm::scaled(a4,b[4]), v);
+      gmm::add(gmm::scaled(a2,b[2]), v);
+      gmm::add(gmm::scaled(a6,b[7]),    // u = b7*A6 + b5*A4 + b3*A2 + b1*I
+               gmm::scaled(a4,b[5]), u);
+      gmm::add(gmm::scaled(a2,b[3]), u);
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+
+      base_matrix da2(aexp); // zero init
+      for (size_type l=0; l < N; ++l)
+        for (size_type k=0; k < N; ++k) {
+          gmm::clear(da2); gmm::clear(dv_); gmm::clear(du_);
+          for (size_type ij=0; ij < N; ++ij) {
+            const auto &al=a(l,ij), &ak=a(ij,k);
+            da2(k,ij) += al;        da2(ij,l) += ak;
+            dv_(k,ij) += b[2]*al;   dv_(ij,l) += b[2]*ak;
+            du_(k,ij) += b[3]*al;   du_(ij,l) += b[3]*ak;
+          }
+          matmul(a2,da2,tmp);
+          matmul(da2,a2,tmp_);
+          gmm::add(tmp_,tmp);        // tmp derivative of a4
+          gmm::add(gmm::scaled(tmp,b[4]), dv_);
+          gmm::add(gmm::scaled(tmp,b[5]), du_);
+
+          matmul(a2,tmp,tmp_);
+          matmul(da2,a4,tmp);
+          gmm::add(tmp_,tmp);        // tmp derivative of a6
+          gmm::add(gmm::scaled(tmp,b[6]), dv_);
+          gmm::add(gmm::scaled(tmp,b[7]), du_);
+
+          std::swap(du_,tmp); // derivative of u <-- A*u
+          matmul(a,tmp,du_);
+          for (size_type j=0; j < N; ++j) // i == k
+            du_(k,j) += u(l,j);
+
+          std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l));
+          std::copy(du_.begin(),du_.end(), &du(0,0,k,l));
+        }
+    } else if (l1norm < 2.097847961257068e+000) { // matrix_exp_pade9(a, u, v)
+      const static std::array<scalar_type,10>
+        b{17643225600,8821612800,2075673600,302702400,30270240,2162160,
+          110880,3960,90,1};
+      base_matrix a2(a), a4(a), a6(a), a8(a);
+      matmul(a, a, a2);
+      matmul(a2, a2, a4);
+      matmul(a2, a4, a6);
+      matmul(a4, a4, a8);
+      gmm::add(gmm::scaled(a8,b[8]),    // v = b8*A8+b6*A6+b4*A4+b2*A2+b0*I
+               gmm::scaled(a6,b[6]), v);
+      gmm::add(gmm::scaled(a4,b[4]), v);
+      gmm::add(gmm::scaled(a2,b[2]), v);
+      gmm::add(gmm::scaled(a8,b[9]),    // u = b9*A8+b7*A6+b5*A4+b3*A2+b1*I
+               gmm::scaled(a6,b[7]), u);
+      gmm::add(gmm::scaled(a4,b[5]), u);
+      gmm::add(gmm::scaled(a2,b[3]), u);
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+
+      base_matrix da2(aexp), da4(aexp); // zero init
+      for (size_type l=0; l < N; ++l)
+        for (size_type k=0; k < N; ++k) {
+          gmm::clear(da2); gmm::clear(dv_); gmm::clear(du_);
+          for (size_type ij=0; ij < N; ++ij) {
+            const auto &al=a(l,ij), &ak=a(ij,k);
+            da2(k,ij) += al;        da2(ij,l) += ak;
+            dv_(k,ij) += b[2]*al;   dv_(ij,l) += b[2]*ak;
+            du_(k,ij) += b[3]*al;   du_(ij,l) += b[3]*ak;
+          }
+          matmul(a2,da2,tmp);
+          matmul(da2,a2,da4);
+          gmm::add(tmp,da4);
+          gmm::add(gmm::scaled(da4,b[4]), dv_);
+          gmm::add(gmm::scaled(da4,b[5]), du_);
+
+          matmul(a2,da4,tmp_);
+          matmul(da2,a4,tmp);
+          gmm::add(tmp_,tmp);        // tmp derivative of a6
+          gmm::add(gmm::scaled(tmp,b[6]), dv_);
+          gmm::add(gmm::scaled(tmp,b[7]), du_);
+
+          matmul(a4,da4,tmp);
+          matmul(da4,a4,tmp_);
+          gmm::add(tmp_,tmp);        // tmp derivative of a8
+          gmm::add(gmm::scaled(tmp,b[8]), dv_);
+          gmm::add(gmm::scaled(tmp,b[9]), du_);
+
+          std::swap(du_,tmp); // derivative of u <-- A*u
+          matmul(a,tmp,du_);
+          for (size_type j=0; j < N; ++j) // i == k
+            du_(k,j) += u(l,j);
+
+          std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l));
+          std::copy(du_.begin(),du_.end(), &du(0,0,k,l));
+        }
+    } else { // matrix_exp_pade13(a, u, v)
+      const scalar_type maxnorm = 5.371920351148152;
+      frexp(l1norm / maxnorm, &e);
+      if (e <= 0) e = 0;
+      else for (auto &&val : a.as_vector()) { val = ldexp(val,-e); }
+           // <==> gmm::scale(a, pow(scalar_type(2),-scalar_type(e)));
+      const static std::array<scalar_type,14>
+        b{64764752532480000,32382376266240000,7771770303897600,
+          1187353796428800,129060195264000,10559470521600,670442572800,
+          33522128640,1323241920,40840800,960960,16380,182,1};
+      base_matrix a2(a), a4(a), a6(a), v_(a), u_(a);
+      matmul(a, a, a2);
+      matmul(a2, a2, a4);
+      matmul(a2, a4, a6);
+      gmm::add(gmm::scaled(a6,b[12]),
+               gmm::scaled(a4,b[10]), v_);
+      gmm::add(gmm::scaled(a2,b[8]), v_);
+      matmul(a6, v_, v);              // v = b12*A12+b10*A10+b8*A8
+      gmm::add(gmm::scaled(a6,b[6]), v); //   + b6*A6+b4*A4+b2*A2+b0*I
+      gmm::add(gmm::scaled(a4,b[4]), v);
+      gmm::add(gmm::scaled(a2,b[2]), v);
+
+      gmm::add(gmm::scaled(a6,b[13]),
+               gmm::scaled(a4,b[11]), u_);
+      gmm::add(gmm::scaled(a2,b[9]), u_);
+      matmul(a6, u_, u);              // u = b13*A12+b11*A10+b9*A8
+      gmm::add(gmm::scaled(a6,b[7]), u); //   + b7*A6+b5*A4+b3*A2+b1*I
+      gmm::add(gmm::scaled(a4,b[5]), u);
+      gmm::add(gmm::scaled(a2,b[3]), u);
+      for (size_type ij=0; ij < N; ++ij)
+        { v(ij,ij) += b[0]; u(ij,ij) += b[1]; }
+
+      base_matrix da2(aexp), da4(aexp), da6(aexp),
+                  dv__(aexp), du__(aexp);
+      for (size_type l=0; l < N; ++l)
+        for (size_type k=0; k < N; ++k) {
+          gmm::clear(da2); gmm::clear(dv_); gmm::clear(du_);
+          gmm::clear(dv__); gmm::clear(du__);
+          for (size_type ij=0; ij < N; ++ij) {
+            const auto &al=a(l,ij), &ak=a(ij,k);
+            da2(k,ij) += al;        da2(ij,l) += ak;
+            dv_(k,ij) += b[2]*al;   dv_(ij,l) += b[2]*ak;
+            du_(k,ij) += b[3]*al;   du_(ij,l) += b[3]*ak;
+            dv__(k,ij) += b[8]*al;  dv__(ij,l) += b[8]*ak;
+            du__(k,ij) += b[9]*al;  du__(ij,l) += b[9]*ak;
+          }
+          matmul(a2,da2,da4);
+          matmul(da2,a2,tmp); gmm::add(tmp,da4);
+          gmm::add(gmm::scaled(da4,b[4]), dv_);
+          gmm::add(gmm::scaled(da4,b[5]), du_);
+          gmm::add(gmm::scaled(da4,b[10]), dv__);
+          gmm::add(gmm::scaled(da4,b[11]), du__);
+
+          matmul(a2,da4,da6);
+          matmul(da2,a4,tmp); gmm::add(tmp,da6);
+          gmm::add(gmm::scaled(da6,b[6]), dv_);
+          gmm::add(gmm::scaled(da6,b[7]), du_);
+          gmm::add(gmm::scaled(da6,b[12]), dv__);
+          gmm::add(gmm::scaled(da6,b[13]), du__);
+
+          matmul(a6,dv__,tmp); gmm::add(tmp, dv_);
+          matmul(da6,v_,tmp);  gmm::add(tmp, dv_);
+
+          matmul(a6,du__,tmp); gmm::add(tmp, du_);
+          matmul(da6,u_,tmp);  gmm::add(tmp, du_);
+
+          std::swap(du_,tmp); // derivative of u <-- A*u
+          matmul(a,tmp,du_);
+          for (size_type j=0; j < N; ++j) // i == k
+            du_(k,j) += u(l,j);
+
+          std::copy(dv_.begin(),dv_.end(), &dv(0,0,k,l));
+          std::copy(du_.begin(),du_.end(), &du(0,0,k,l));
         }
-      gmm::mult(aexp, aexp, atmp);
-      gmm::copy(atmp, aexp);
     }
+    std::swap(u, tmp);
+    matmul(a, tmp, u);              // u <-- A*u
+
+    base_matrix inv_denom(v);
+    gmm::add(gmm::scaled(u,-1),inv_denom); // denom = v-u
+    gmm::lu_inverse(inv_denom);
+
+    gmm::add(u,v,tmp);                     // tmp = numer = v+u
+    matmul(inv_denom,tmp,aexp);
+
+    for (size_type l=0; l < N; ++l)
+      for (size_type k=0; k < N; ++k) { // daexp_kl= D\(dN_kl-dD_kl*aexp)
+        std::copy(&dv(0,0,k,l),&dv(0,0,k,l)+N*N, tmp_.begin());
+        std::copy(&du(0,0,k,l),&du(0,0,k,l)+N*N, tmp.begin());
+        gmm::add(gmm::scaled(tmp_/*dv*/,-1),tmp/*du*/); // tmp = -(dv-du)
+        matmul(tmp,aexp,tmp_);
+        std::copy(&du(0,0,k,l),&du(0,0,k,l)+N*N, tmp.begin());
+        gmm::add(tmp/*du*/, tmp_);
+        std::copy(&dv(0,0,k,l),&dv(0,0,k,l)+N*N, tmp.begin());
+        gmm::add(tmp/*dv*/, tmp_); // tmp_ = (dv+du)_kl-(dv-du)_kl*aexp 
+        matmul(inv_denom, tmp_, tmp);
+        std::copy(tmp.begin(), tmp.end(), &daexp(0,0,k,l));
+      }
+    if (e)
+      for (auto &&val : daexp.as_vector()) { val = ldexp(val,-e); }
 
-    if (paexp) gmm::copy(aexp, *paexp);
+    for (int i=0; i < e; ++i) { // unscale result
+      for (size_type l=0; l < N; ++l)
+        for (size_type k=0; k < N; ++k) {
+          std::copy(&daexp(0,0,k,l), &daexp(0,0,k,l)+N*N, tmp.begin());
+          matmul(tmp, aexp, u); // u,v used as temporaries
+          matmul(aexp, tmp, v); //
+          gmm::add(u, v, tmp);
+          std::copy(tmp.begin(), tmp.end(), &daexp(0,0,k,l));
+        }
+      std::swap(aexp,tmp);
+      matmul(tmp, tmp, aexp);
+    }
     return true;
   }
 



reply via email to

[Prev in Thread] Current Thread [Next in Thread]