octave-maintainers
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[PATCH 4 of 4] Implement diag + sparse, diag - sparse, sparse + diag, s


From: Jason Riedy
Subject: [PATCH 4 of 4] Implement diag + sparse, diag - sparse, sparse + diag, sparse - diag
Date: Mon, 09 Mar 2009 17:52:22 -0400

# HG changeset patch
# User Jason Riedy <address@hidden>
# Date 1236635354 14400
# Node ID 951723129240c9880b6675fefbbb06a8ba267fe1
# Parent  64103815307e702a6de6ddc2115e84c3636369cf
Implement diag + sparse, diag - sparse, sparse + diag, sparse - diag.

>From 2c4c62a9669d1d43885d5e84c07c17a60321b10b Mon Sep 17 00:00:00 2001
Date: Mon, 9 Mar 2009 17:45:22 -0400
This does not use the typical sparse-mx-ops generator.  I suspect the
semantics of elementwise multiplication and division to be rather
controversial, so they are not included.  If comparison operations are
added, the implementation should be shifted over to use the typical
generator.

The template in Sparse-diag-op-defs.h likely could use const bools
rather than functional argument operations.  I haven't measured which
is optimized more effectively.

Also, the Octave binding layer in op-dm-scm.cc likely could use all
sorts of macro or template trickery, but it's far easier to let Emacs
handle it for now.  That would be worth revisiting if further
elementwise sparse and diagonal operations are added.

Signed-off-by: Jason Riedy <address@hidden>
---
 liboctave/CSparse.cc            |   62 ++++++
 liboctave/CSparse.h             |   14 ++
 liboctave/ChangeLog             |   37 ++++
 liboctave/Sparse-diag-op-defs.h |  112 +++++++++++
 liboctave/dSparse.cc            |   24 +++
 liboctave/dSparse.h             |    5 +
 src/ChangeLog                   |   22 +++
 src/OPERATORS/op-dm-scm.cc      |  388 +++++++++++++++++++++++++++++++++++++++
 src/OPERATORS/op-dm-sm.cc       |  128 +++++++++++++
 test/ChangeLog                  |    4 +
 test/test_diag_perm.m           |   25 +++
 11 files changed, 821 insertions(+), 0 deletions(-)

diff --git a/liboctave/CSparse.cc b/liboctave/CSparse.cc
--- a/liboctave/CSparse.cc
+++ b/liboctave/CSparse.cc
@@ -7686,6 +7686,68 @@
     octave_impl::conj_val<ComplexDiagMatrix::element_type> ());
 }
 
+SparseComplexMatrix  // D + A: complex diagonal, real sparse
+operator + (const ComplexDiagMatrix& d, const SparseMatrix& a)
+{
+  return octave_impl::do_add_dm_sm<SparseComplexMatrix> (d, a);
+}
+SparseComplexMatrix  // D + A: real diagonal, complex sparse
+operator + (const DiagMatrix& d, const SparseComplexMatrix& a)
+{
+  return octave_impl::do_add_dm_sm<SparseComplexMatrix> (d, a);
+}
+SparseComplexMatrix  // D + A: both complex
+operator + (const ComplexDiagMatrix& d, const SparseComplexMatrix& a)
+{
+  return octave_impl::do_add_dm_sm<SparseComplexMatrix> (d, a);
+}
+SparseComplexMatrix  // A + D: real sparse, complex diagonal
+operator + (const SparseMatrix& a, const ComplexDiagMatrix& d)
+{
+  return octave_impl::do_add_sm_dm<SparseComplexMatrix> (a, d);
+}
+SparseComplexMatrix  // A + D: complex sparse, real diagonal
+operator + (const SparseComplexMatrix& a, const DiagMatrix& d)
+{
+  return octave_impl::do_add_sm_dm<SparseComplexMatrix> (a, d);
+}
+SparseComplexMatrix  // A + D: both complex
+operator + (const SparseComplexMatrix&a, const ComplexDiagMatrix& d)
+{
+  return octave_impl::do_add_sm_dm<SparseComplexMatrix> (a, d);
+}
+
+SparseComplexMatrix  // D - A: complex diagonal, real sparse
+operator - (const ComplexDiagMatrix& d, const SparseMatrix& a)
+{
+  return octave_impl::do_sub_dm_sm<SparseComplexMatrix> (d, a);
+}
+SparseComplexMatrix  // D - A: real diagonal, complex sparse
+operator - (const DiagMatrix& d, const SparseComplexMatrix& a)
+{
+  return octave_impl::do_sub_dm_sm<SparseComplexMatrix> (d, a);
+}
+SparseComplexMatrix  // D - A: both complex
+operator - (const ComplexDiagMatrix& d, const SparseComplexMatrix& a)
+{
+  return octave_impl::do_sub_dm_sm<SparseComplexMatrix> (d, a);
+}
+SparseComplexMatrix  // A - D: real sparse, complex diagonal
+operator - (const SparseMatrix& a, const ComplexDiagMatrix& d)
+{
+  return octave_impl::do_sub_sm_dm<SparseComplexMatrix> (a, d);
+}
+SparseComplexMatrix  // A - D: complex sparse, real diagonal
+operator - (const SparseComplexMatrix& a, const DiagMatrix& d)
+{
+  return octave_impl::do_sub_sm_dm<SparseComplexMatrix> (a, d);
+}
+SparseComplexMatrix  // A - D: both complex
+operator - (const SparseComplexMatrix&a, const ComplexDiagMatrix& d)
+{
+  return octave_impl::do_sub_sm_dm<SparseComplexMatrix> (a, d);
+}
+
 // FIXME -- it would be nice to share code among the min/max
 // functions below.
 
diff --git a/liboctave/CSparse.h b/liboctave/CSparse.h
--- a/liboctave/CSparse.h
+++ b/liboctave/CSparse.h
@@ -496,6 +496,20 @@
 extern OCTAVE_API SparseComplexMatrix mul_trans (const SparseComplexMatrix&, 
const ComplexDiagMatrix&);
 extern OCTAVE_API SparseComplexMatrix mul_herm (const SparseComplexMatrix&, 
const ComplexDiagMatrix&);
 
+extern OCTAVE_API SparseComplexMatrix operator + (const ComplexDiagMatrix&, 
const SparseMatrix&);  // diag + sparse overloads; each returns a complex sparse matrix
+extern OCTAVE_API SparseComplexMatrix operator + (const DiagMatrix&, const 
SparseComplexMatrix&);
+extern OCTAVE_API SparseComplexMatrix operator + (const ComplexDiagMatrix&, 
const SparseComplexMatrix&);
+extern OCTAVE_API SparseComplexMatrix operator + (const SparseMatrix&, const 
ComplexDiagMatrix&);  // sparse + diag overloads
+extern OCTAVE_API SparseComplexMatrix operator + (const SparseComplexMatrix&, 
const DiagMatrix&);
+extern OCTAVE_API SparseComplexMatrix operator + (const SparseComplexMatrix&, 
const ComplexDiagMatrix&);
+
+extern OCTAVE_API SparseComplexMatrix operator - (const ComplexDiagMatrix&, 
const SparseMatrix&);  // diag - sparse overloads; each returns a complex sparse matrix
+extern OCTAVE_API SparseComplexMatrix operator - (const DiagMatrix&, const 
SparseComplexMatrix&);
+extern OCTAVE_API SparseComplexMatrix operator - (const ComplexDiagMatrix&, 
const SparseComplexMatrix&);
+extern OCTAVE_API SparseComplexMatrix operator - (const SparseMatrix&, const 
ComplexDiagMatrix&);  // sparse - diag overloads
+extern OCTAVE_API SparseComplexMatrix operator - (const SparseComplexMatrix&, 
const DiagMatrix&);
+extern OCTAVE_API SparseComplexMatrix operator - (const SparseComplexMatrix&, 
const ComplexDiagMatrix&);
+
 extern OCTAVE_API SparseComplexMatrix min (const Complex& c, 
                                const SparseComplexMatrix& m);
 extern OCTAVE_API SparseComplexMatrix min (const SparseComplexMatrix& m, 
diff --git a/liboctave/ChangeLog b/liboctave/ChangeLog
--- a/liboctave/ChangeLog
+++ b/liboctave/ChangeLog
@@ -1,3 +1,40 @@
+2009-03-09  Jason Riedy  <address@hidden>
+
+       * Sparse-diag-op-defs.h (octave_impl::inner_do_add_sm_dm): New
+       template function.  Implementation for adding sparse and diagonal
+       matrices.  Takes two functional arguments, opa and opd, to
+       generate both subtraction variants.
+       (octave_impl::do_commutative_add_dm_sm): New template function.
+       Ensure A+D and D+A use the same generated code.
+       (octave_impl::do_add_dm_sm): New template function.  Check
+       arguments for diag + sparse and call inner routine.
+       (octave_impl::do_sub_dm_sm): New template function.  Check
+       arguments for diag - sparse and call inner routine.
+       (octave_impl::do_add_sm_dm): New template function.  Check
+       arguments for sparse + diag and call inner routine.
+       (octave_impl::do_sub_sm_dm): New template function.  Check
+       arguments for sparse - diag and call inner routine.
+
+       * dSparse.h (operator +): Declare overrides for real diag +
+       sparse, sparse + diag.
+       (operator -): Declare overrides for real diag - sparse, sparse -
+       diag.
+
+       * dSparse.cc (operator +): Define overrides for real diag +
+       sparse, sparse + diag.
+       (operator -): Define overrides for real diag - sparse, sparse -
+       diag.
+
+       * CSparse.h (operator +): Declare overrides for complex and real
+       combinations of diag + sparse, sparse + diag.
+       (operator -): Declare overrides for complex and real combinations
+       of diag - sparse, sparse - diag.
+
+       * CSparse.cc (operator +): Define overrides for complex and real
+       combinations of diag + sparse, sparse + diag.
+       (operator -): Define overrides for complex and real combinations
+       of diag - sparse, sparse - diag.
+
 2009-03-08  Jason Riedy  <address@hidden>
 
        * Sparse-diag-op-defs.h (octave_impl::do_mul_dm_sm)
diff --git a/liboctave/Sparse-diag-op-defs.h b/liboctave/Sparse-diag-op-defs.h
--- a/liboctave/Sparse-diag-op-defs.h
+++ b/liboctave/Sparse-diag-op-defs.h
@@ -25,6 +25,8 @@
 
 namespace octave_impl {
 
+// Matrix multiplication
+
 template <typename RT, typename DM, typename SM, typename UnaryOp>
 RT do_mul_dm_sm (const DM& d, const SM& a, UnaryOp conj_fn)
 {
@@ -112,6 +114,116 @@
   return do_mul_sm_dm<RT> (a, d, octave_impl::identity_val<typename 
DM::element_type>());
 }
 
+// Matrix addition
+
+template <typename RT, typename SM, typename DM, typename OpA, typename OpD>
+RT inner_do_add_sm_dm (const SM& a, const DM& d, OpA opa, OpD opd)
+{
+  // Build opa (A) + opd (D) as a sparse RT; callers check conformance.  Only
+  // the min (nr, nc) stored diagonal elements are used, kept in row order.
+  using std::min;
+  const octave_idx_type nr = d.rows ();
+  const octave_idx_type nc = d.cols ();
+  const octave_idx_type n = min (nr, nc);
+
+  const octave_idx_type a_nr = a.rows ();
+  const octave_idx_type a_nc = a.cols ();
+  const octave_idx_type nz = a.nnz ();
+  RT r (a_nr, a_nc, nz + n);  // room for A's entries plus a full diagonal
+  octave_idx_type k = 0;      // next free slot in r
+  for (octave_idx_type j = 0; j < nc; ++j)
+    {
+      OCTAVE_QUIT;
+      const octave_idx_type colend = a.cidx (j+1);
+      bool need_diag = j < n;  // column j still owes a diagonal entry
+      r.xcidx (j) = k;
+      for (octave_idx_type k_src = a.cidx (j); k_src < colend; ++k_src, ++k)
+        {
+          const octave_idx_type i = a.ridx (k_src);
+          const bool on_diag = need_diag && i == j;
+          if (need_diag && i > j)
+            {
+              r.xridx (k) = j;  // A has no (j,j): insert D's before row i
+              r.xdata (k++) = opd (d.dgelem (j));
+            }
+          need_diag = need_diag && i < j;
+          r.xridx (k) = i;
+          if (on_diag)
+            r.xdata (k) = opa (a.data (k_src)) + opd (d.dgelem (j));
+          else
+            r.xdata (k) = opa (a.data (k_src));
+        }
+      if (need_diag)  // every entry of column j lies above the diagonal
+        {
+          r.xridx (k) = j;
+          r.xdata (k++) = opd (d.dgelem (j));
+        }
+    }
+  r.xcidx (nc) = k;
+  r.maybe_compress (true);
+  return r;
+}
+
+template <typename RT, typename DM, typename SM>
+RT do_commutative_add_dm_sm (const DM& d, const SM& a)
+{
+  // Shared by D + A and A + D so both orders go through one inner call.
+  return inner_do_add_sm_dm<RT> (a, d,
+                                octave_impl::identity_val<typename 
SM::element_type> (),
+                                octave_impl::identity_val<typename 
DM::element_type> ());
+}
+
+template <typename RT, typename DM, typename SM>
+RT do_add_dm_sm (const DM& d, const SM& a)  // D + A with a dimension check
+{
+  if (a.rows () != d.rows () || a.cols () != d.cols ())
+    {
+      gripe_nonconformant ("operator +", d.rows (), d.cols (), a.rows (), 
a.cols ());
+      return RT ();  // shapes differ: report it, return a default RT
+    }
+  else
+    return do_commutative_add_dm_sm<RT> (d, a);
+}
+
+template <typename RT, typename DM, typename SM>
+RT do_sub_dm_sm (const DM& d, const SM& a)  // D - A with a dimension check
+{
+  if (a.rows () != d.rows () || a.cols () != d.cols ())
+    {
+      gripe_nonconformant ("operator -", d.rows (), d.cols (), a.rows (), 
a.cols ());
+      return RT ();  // shapes differ: report it, return a default RT
+    }
+  else
+    return inner_do_add_sm_dm<RT> (a, d, std::negate<typename 
SM::element_type> (),  // D - A == (-A) + D: negate the sparse entries
+                                  octave_impl::identity_val<typename 
DM::element_type> ());
+}
+
+template <typename RT, typename SM, typename DM>
+RT do_add_sm_dm (const SM& a, const DM& d)  // A + D with a dimension check
+{
+  if (a.rows () != d.rows () || a.cols () != d.cols ())
+    {
+      gripe_nonconformant ("operator +", a.rows (), a.cols (), d.rows (), 
d.cols ());
+      return RT ();  // shapes differ: report it, return a default RT
+    }
+  else
+    return do_commutative_add_dm_sm<RT> (d, a);  // same path as D + A
+}
+
+template <typename RT, typename SM, typename DM>
+RT do_sub_sm_dm (const SM& a, const DM& d)  // A - D with a dimension check
+{
+  if (a.rows () != d.rows () || a.cols () != d.cols ())
+    {
+      gripe_nonconformant ("operator -", a.rows (), a.cols (), d.rows (), 
d.cols ());
+      return RT ();  // shapes differ: report it, return a default RT
+    }
+  else
+    return inner_do_add_sm_dm<RT> (a, d,
+                                  octave_impl::identity_val<typename 
SM::element_type> (),
+                                  std::negate<typename DM::element_type> ());  // A - D == A + (-D): negate the diagonal
+}
+
 } // namespace octave_impl
 
 #endif // octave_sparse_diag_op_defs_h
diff --git a/liboctave/dSparse.cc b/liboctave/dSparse.cc
--- a/liboctave/dSparse.cc
+++ b/liboctave/dSparse.cc
@@ -7725,6 +7725,30 @@
   return operator * (a, d);
 }
 
+SparseMatrix  // real D + A, returned as a sparse matrix
+operator + (const DiagMatrix& d, const SparseMatrix& a)
+{
+  return octave_impl::do_add_dm_sm<SparseMatrix> (d, a);
+}
+
+SparseMatrix  // real D - A, returned as a sparse matrix
+operator - (const DiagMatrix& d, const SparseMatrix& a)
+{
+  return octave_impl::do_sub_dm_sm<SparseMatrix> (d, a);
+}
+
+SparseMatrix  // real A + D, returned as a sparse matrix
+operator + (const SparseMatrix& a, const DiagMatrix& d)
+{
+  return octave_impl::do_add_sm_dm<SparseMatrix> (a, d);
+}
+
+SparseMatrix  // real A - D, returned as a sparse matrix
+operator - (const SparseMatrix& a, const DiagMatrix& d)
+{
+  return octave_impl::do_sub_sm_dm<SparseMatrix> (a, d);
+}
+
 // FIXME -- it would be nice to share code among the min/max
 // functions below.
 
diff --git a/liboctave/dSparse.h b/liboctave/dSparse.h
--- a/liboctave/dSparse.h
+++ b/liboctave/dSparse.h
@@ -456,6 +456,11 @@
 extern OCTAVE_API SparseMatrix operator * (const SparseMatrix&, const 
DiagMatrix&);
 extern OCTAVE_API SparseMatrix mul_trans (const SparseMatrix&, const 
DiagMatrix&);
 
+extern OCTAVE_API SparseMatrix operator + (const DiagMatrix&, const 
SparseMatrix&);  // real diag/sparse sums and differences, in both operand orders
+extern OCTAVE_API SparseMatrix operator + (const SparseMatrix&, const 
DiagMatrix&);
+extern OCTAVE_API SparseMatrix operator - (const DiagMatrix&, const 
SparseMatrix&);
+extern OCTAVE_API SparseMatrix operator - (const SparseMatrix&, const 
DiagMatrix&);
+
 extern OCTAVE_API SparseMatrix min (double d, const SparseMatrix& m);
 extern OCTAVE_API SparseMatrix min (const SparseMatrix& m, double d);
 extern OCTAVE_API SparseMatrix min (const SparseMatrix& a, const SparseMatrix& 
b);
diff --git a/src/ChangeLog b/src/ChangeLog
--- a/src/ChangeLog
+++ b/src/ChangeLog
@@ -1,3 +1,25 @@
+2009-03-09  Jason Riedy  <address@hidden>
+
+       * OPERATORS/op-dm-sm.cc (add_dm_sm): Octave binding for diag + sparse.
+       (sub_dm_sm): Octave binding for diag - sparse.
+       (add_sm_dm): Octave binding for sparse + diag.
+       (sub_sm_dm): Octave binding for sparse - diag.
+       (install_dm_sm_ops): Install above bindings.
+
+       * OPERATORS/op-dm-scm.cc (add_cdm_sm): Octave binding for diag + sparse.
+       (add_dm_scm): Octave binding for diag + sparse.
+       (add_cdm_scm): Octave binding for diag + sparse.
+       (sub_cdm_sm): Octave binding for diag - sparse.
+       (sub_dm_scm): Octave binding for diag - sparse.
+       (sub_cdm_scm): Octave binding for diag - sparse.
+       (add_sm_cdm): Octave binding for sparse + diag.
+       (add_scm_dm): Octave binding for sparse + diag.
+       (add_scm_cdm): Octave binding for sparse + diag.
+       (sub_sm_cdm): Octave binding for sparse - diag.
+       (sub_scm_dm): Octave binding for sparse - diag.
+       (sub_scm_cdm): Octave binding for sparse - diag.
+       (install_dm_scm_ops): Install above bindings.
+
 2009-03-08  Jason Riedy  <address@hidden>
 
        * sparse-xdiv.h (xleftdiv): Declare overrides for
diff --git a/src/OPERATORS/op-dm-scm.cc b/src/OPERATORS/op-dm-scm.cc
--- a/src/OPERATORS/op-dm-scm.cc
+++ b/src/OPERATORS/op-dm-scm.cc
@@ -228,6 +228,192 @@
                   typ);
 }
 
+DEFBINOP (add_dm_scm, diag_matrix, sparse_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_diag_matrix&, const octave_sparse_complex_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Real diag + complex sparse.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the diag + sparse routine.
+    {
+      std::complex<double> d = v2.complex_value ();
+
+      return octave_value (v1.diag_matrix_value () + d);
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 diagonal operand is a scalar too; add it to the sparse v2.
+    {
+      double d = v1.scalar_value ();
+
+      return octave_value (d + v2.sparse_complex_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v2.matrix_type ();
+      SparseComplexMatrix ret = v1.diag_matrix_value () + v2.sparse_complex_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (add_cdm_sm, complex_diag_matrix, sparse_matrix)
+{
+  CAST_BINOP_ARGS (const octave_complex_diag_matrix&, const octave_sparse_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Complex diag + real sparse.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the diag + sparse routine.
+    {
+      double d = v2.scalar_value ();
+
+      return octave_value (v1.complex_diag_matrix_value () + d);
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 diagonal operand is a scalar too; add it to the sparse v2.
+    {
+      std::complex<double> d = v1.complex_value ();
+
+      return octave_value (d + v2.sparse_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v2.matrix_type ();
+      SparseComplexMatrix ret = v1.complex_diag_matrix_value () + v2.sparse_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (add_cdm_scm, complex_diag_matrix, sparse_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_complex_diag_matrix&, const octave_sparse_complex_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Complex diag + complex sparse.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the diag + sparse routine.
+    {
+      std::complex<double> d = v2.complex_value ();
+
+      return octave_value (v1.complex_diag_matrix_value () + d);
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 diagonal operand is a (complex) scalar; add it to the sparse v2.
+    {
+      std::complex<double> d = v1.complex_value ();
+
+      return octave_value (d + v2.sparse_complex_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v2.matrix_type ();
+      SparseComplexMatrix ret = v1.complex_diag_matrix_value () + v2.sparse_complex_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (sub_dm_scm, diag_matrix, sparse_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_diag_matrix&, const octave_sparse_complex_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Real diag - complex sparse.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the diag - sparse routine.
+    {
+      std::complex<double> d = v2.complex_value ();
+
+      return octave_value (v1.diag_matrix_value () + (-d));
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 diagonal operand is a scalar too; subtract the sparse v2 from it.
+    {
+      double d = v1.scalar_value ();
+
+      return octave_value (d - v2.sparse_complex_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v2.matrix_type ();
+      SparseComplexMatrix ret = v1.diag_matrix_value () - v2.sparse_complex_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (sub_cdm_sm, complex_diag_matrix, sparse_matrix)
+{
+  CAST_BINOP_ARGS (const octave_complex_diag_matrix&, const octave_sparse_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Complex diag - real sparse.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the diag - sparse routine.
+    {
+      double d = v2.scalar_value ();
+
+      return octave_value (v1.complex_diag_matrix_value () + (-d));
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 diagonal operand is a scalar too; subtract the sparse v2 from it.
+    {
+      std::complex<double> d = v1.complex_value ();
+
+      return octave_value (d - v2.sparse_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v2.matrix_type ();
+      SparseComplexMatrix ret = v1.complex_diag_matrix_value () - v2.sparse_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (sub_cdm_scm, complex_diag_matrix, sparse_complex_matrix)
+{
+  CAST_BINOP_ARGS (const octave_complex_diag_matrix&, const octave_sparse_complex_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Complex diag - complex sparse.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the diag - sparse routine.
+    {
+      std::complex<double> d = v2.complex_value ();
+
+      return octave_value (v1.complex_diag_matrix_value () + (-d));
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 diagonal operand is a (complex) scalar; subtract sparse v2 from it.
+    {
+      std::complex<double> d = v1.complex_value ();
+
+      return octave_value (d - v2.sparse_complex_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v2.matrix_type ();
+      SparseComplexMatrix ret = v1.complex_diag_matrix_value () - v2.sparse_complex_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
 // sparse matrix by diagonal matrix ops
 
 DEFBINOP (mul_scm_dm, sparse_complex_matrix, diag_matrix)
@@ -448,6 +634,192 @@
     }
 }
 
+DEFBINOP (add_sm_cdm, sparse_matrix, complex_diag_matrix)
+{
+  CAST_BINOP_ARGS (const octave_sparse_matrix&, const octave_complex_diag_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Real sparse + complex diag.  A 1x1 diagonal operand is handled as a
+    // plain scalar instead of going through the sparse + diag routine.
+    {
+      std::complex<double> d = v2.complex_value ();
+
+      return octave_value (v1.sparse_matrix_value () + d);
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 sparse operand is a scalar too; add it to the diagonal v2.
+    {
+      double d = v1.scalar_value ();
+
+      return octave_value (d + v2.complex_diag_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v1.matrix_type ();
+      SparseComplexMatrix ret = v1.sparse_matrix_value () + v2.complex_diag_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (add_scm_dm, sparse_complex_matrix, diag_matrix)
+{
+  CAST_BINOP_ARGS (const octave_sparse_complex_matrix&, const octave_diag_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Complex sparse + real diag.  A 1x1 diagonal operand is handled as a
+    // plain scalar instead of going through the sparse + diag routine.
+    {
+      double d = v2.scalar_value ();
+
+      return octave_value (v1.sparse_complex_matrix_value () + d);
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 sparse operand is a scalar too; add it to the diagonal v2.
+    {
+      std::complex<double> d = v1.complex_value ();
+
+      return octave_value (d + v2.diag_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v1.matrix_type ();
+      SparseComplexMatrix ret = v1.sparse_complex_matrix_value () + v2.diag_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (add_scm_cdm, sparse_complex_matrix, complex_diag_matrix)
+{
+  CAST_BINOP_ARGS (const octave_sparse_complex_matrix&, const octave_complex_diag_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Complex sparse + complex diag.  A 1x1 diagonal operand is handled as a
+    // plain scalar instead of going through the sparse + diag routine.
+    {
+      std::complex<double> d = v2.complex_value ();
+
+      return octave_value (v1.sparse_complex_matrix_value () + d);
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 sparse operand is a (complex) scalar; add it to the diagonal v2.
+    {
+      std::complex<double> d = v1.complex_value ();
+
+      return octave_value (d + v2.complex_diag_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v1.matrix_type ();
+      SparseComplexMatrix ret = v1.sparse_complex_matrix_value () + v2.complex_diag_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (sub_sm_cdm, sparse_matrix, complex_diag_matrix)
+{
+  CAST_BINOP_ARGS (const octave_sparse_matrix&, const octave_complex_diag_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Real sparse - complex diag.  A 1x1 diagonal operand is handled as a
+    // plain scalar instead of going through the sparse - diag routine.
+    {
+      std::complex<double> d = v2.complex_value ();
+
+      return octave_value (v1.sparse_matrix_value () + (-d));
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 sparse operand is a scalar too; subtract the diagonal v2 from it.
+    {
+      double d = v1.scalar_value ();
+
+      return octave_value (d - v2.complex_diag_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v1.matrix_type ();
+      SparseComplexMatrix ret = v1.sparse_matrix_value () - v2.complex_diag_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (sub_scm_dm, sparse_complex_matrix, diag_matrix)
+{
+  CAST_BINOP_ARGS (const octave_sparse_complex_matrix&, const octave_diag_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Complex sparse - real diag.  A 1x1 diagonal operand is handled as a
+    // plain scalar instead of going through the sparse - diag routine.
+    {
+      double d = v2.scalar_value ();
+
+      return octave_value (v1.sparse_complex_matrix_value () + (-d));
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 sparse operand is a scalar too; subtract the diagonal v2 from it.
+    {
+      std::complex<double> d = v1.complex_value ();
+
+      return octave_value (d - v2.diag_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v1.matrix_type ();
+      SparseComplexMatrix ret = v1.sparse_complex_matrix_value () - v2.diag_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (sub_scm_cdm, sparse_complex_matrix, complex_diag_matrix)
+{
+  CAST_BINOP_ARGS (const octave_sparse_complex_matrix&, const octave_complex_diag_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Complex sparse - complex diag.  A 1x1 diagonal operand is handled as a
+    // plain scalar instead of going through the sparse - diag routine.
+    {
+      std::complex<double> d = v2.complex_value ();
+
+      return octave_value (v1.sparse_complex_matrix_value () + (-d));
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 sparse operand is a (complex) scalar; subtract diagonal v2 from it.
+    {
+      std::complex<double> d = v1.complex_value ();
+
+      return octave_value (d - v2.complex_diag_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v1.matrix_type ();
+      SparseComplexMatrix ret = v1.sparse_complex_matrix_value () - v2.complex_diag_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
 void
 install_dm_scm_ops (void)
 {
@@ -477,6 +849,15 @@
   INSTALL_BINOP (op_ldiv, octave_complex_diag_matrix, 
octave_sparse_complex_matrix,
                 ldiv_cdm_scm);
 
+  INSTALL_BINOP (op_add, octave_diag_matrix, octave_sparse_complex_matrix, 
add_dm_scm);
+  INSTALL_BINOP (op_add, octave_complex_diag_matrix, octave_sparse_matrix, 
add_cdm_sm);
+  INSTALL_BINOP (op_add, octave_complex_diag_matrix, 
octave_sparse_complex_matrix,
+                add_cdm_scm);
+  INSTALL_BINOP (op_sub, octave_diag_matrix, octave_sparse_complex_matrix, 
sub_dm_scm);
+  INSTALL_BINOP (op_sub, octave_complex_diag_matrix, octave_sparse_matrix, 
sub_cdm_sm);
+  INSTALL_BINOP (op_sub, octave_complex_diag_matrix, 
octave_sparse_complex_matrix,
+                sub_cdm_scm);
+
   INSTALL_BINOP (op_mul, octave_sparse_complex_matrix, octave_diag_matrix,
                 mul_scm_dm);
   INSTALL_BINOP (op_mul_trans, octave_sparse_complex_matrix, 
octave_diag_matrix,
@@ -501,4 +882,11 @@
   INSTALL_BINOP (op_div, octave_sparse_complex_matrix, octave_diag_matrix, 
div_scm_dm);
   INSTALL_BINOP (op_div, octave_sparse_matrix, octave_complex_diag_matrix, 
div_sm_cdm);
   INSTALL_BINOP (op_div, octave_sparse_complex_matrix, 
octave_complex_diag_matrix, div_scm_cdm);
+
+  INSTALL_BINOP (op_add, octave_sparse_complex_matrix, octave_diag_matrix, 
add_scm_dm);
+  INSTALL_BINOP (op_add, octave_sparse_matrix, octave_complex_diag_matrix, 
add_sm_cdm);
+  INSTALL_BINOP (op_add, octave_sparse_complex_matrix, 
octave_complex_diag_matrix, add_scm_cdm);
+  INSTALL_BINOP (op_sub, octave_sparse_complex_matrix, octave_diag_matrix, 
sub_scm_dm);
+  INSTALL_BINOP (op_sub, octave_sparse_matrix, octave_complex_diag_matrix, 
sub_sm_cdm);
+  INSTALL_BINOP (op_sub, octave_sparse_complex_matrix, 
octave_complex_diag_matrix, sub_scm_cdm);
 }
diff --git a/src/OPERATORS/op-dm-sm.cc b/src/OPERATORS/op-dm-sm.cc
--- a/src/OPERATORS/op-dm-sm.cc
+++ b/src/OPERATORS/op-dm-sm.cc
@@ -76,6 +76,68 @@
   return xleftdiv (v1.diag_matrix_value (), v2.sparse_matrix_value (), typ);
 }
 
+DEFBINOP (add_dm_sm, diag_matrix, sparse_matrix)
+{
+  CAST_BINOP_ARGS (const octave_diag_matrix&, const octave_sparse_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Real diag + real sparse.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the diag + sparse routine.
+    {
+      double d = v2.scalar_value ();
+
+      return octave_value (v1.diag_matrix_value () + d);
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 diagonal operand is a scalar too; add it to the sparse v2.
+    {
+      double d = v1.scalar_value ();
+
+      return octave_value (d + v2.sparse_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v2.matrix_type ();
+      SparseMatrix ret = v1.diag_matrix_value () + v2.sparse_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (sub_dm_sm, diag_matrix, sparse_matrix)
+{
+  CAST_BINOP_ARGS (const octave_diag_matrix&, const octave_sparse_matrix&);
+
+  if (v2.rows() == 1 && v2.columns() == 1)
+    // Real diag - real sparse.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the diag - sparse routine.
+    {
+      double d = v2.scalar_value ();
+
+      return octave_value (v1.diag_matrix_value () - d);
+    }
+  else if (v1.rows() == 1 && v1.columns() == 1)
+    // A 1x1 diagonal operand is a scalar too; subtract the sparse v2 from it.
+    {
+      double d = v1.scalar_value ();
+
+      return octave_value (d - v2.sparse_matrix_value ());
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v2.matrix_type ();
+      SparseMatrix ret = v1.diag_matrix_value () - v2.sparse_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
 // sparse matrix by diagonal matrix ops
 
 DEFBINOP (mul_sm_dm, sparse_matrix, diag_matrix)
@@ -129,6 +191,68 @@
     }
 }
 
+DEFBINOP (add_sm_dm, sparse_matrix, diag_matrix)
+{
+  CAST_BINOP_ARGS (const octave_sparse_matrix&, const octave_diag_matrix&);
+
+  if (v1.rows() == 1 && v1.columns() == 1)
+    // Real sparse + real diag.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the sparse + diag routine.
+    {
+      double d = v1.scalar_value ();
+
+      return octave_value (d + v2.diag_matrix_value ());
+    }
+  else if (v2.rows() == 1 && v2.columns() == 1)
+    // A 1x1 diagonal operand is a scalar too; add it to the sparse v1.
+    {
+      double d = v2.scalar_value ();
+
+      return octave_value (v1.sparse_matrix_value () + d);
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v1.matrix_type ();
+      SparseMatrix ret = v1.sparse_matrix_value () + v2.diag_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
+DEFBINOP (sub_sm_dm, sparse_matrix, diag_matrix)
+{
+  CAST_BINOP_ARGS (const octave_sparse_matrix&, const octave_diag_matrix&);
+
+  if (v1.rows() == 1 && v1.columns() == 1)
+    // Real sparse - real diag.  A 1x1 sparse operand is handled as a
+    // plain scalar instead of going through the sparse - diag routine.
+    {
+      double d = v1.scalar_value ();
+
+      return octave_value (d - v2.diag_matrix_value ());
+    }
+  else if (v2.rows() == 1 && v2.columns() == 1)
+    // A 1x1 diagonal operand is a scalar too; subtract it from the sparse v1.
+    {
+      double d = v2.scalar_value ();
+
+      return octave_value (v1.sparse_matrix_value () - d);
+    }
+  else
+    {
+      // The result inherits the sparse operand's type, marked unsymmetric.
+      MatrixType typ = v1.matrix_type ();
+      SparseMatrix ret = v1.sparse_matrix_value () - v2.diag_matrix_value ();
+      octave_value out = octave_value (ret);
+      typ.mark_as_unsymmetric ();
+      out.matrix_type (typ);
+      return out;
+    }
+}
+
 void
 install_dm_sm_ops (void)
 {
@@ -139,6 +263,8 @@
   INSTALL_BINOP (op_herm_mul, octave_diag_matrix, octave_sparse_matrix,
                 mul_dm_sm);
 
+  INSTALL_BINOP (op_add, octave_diag_matrix, octave_sparse_matrix, add_dm_sm);
+  INSTALL_BINOP (op_sub, octave_diag_matrix, octave_sparse_matrix, sub_dm_sm);
   INSTALL_BINOP (op_ldiv, octave_diag_matrix, octave_sparse_matrix, 
ldiv_dm_sm);
 
   INSTALL_BINOP (op_mul, octave_sparse_matrix, octave_diag_matrix,
@@ -148,5 +274,7 @@
   INSTALL_BINOP (op_herm_mul, octave_sparse_matrix, octave_diag_matrix,
                 mul_sm_dm);
 
+  INSTALL_BINOP (op_add, octave_sparse_matrix, octave_diag_matrix, add_sm_dm);
+  INSTALL_BINOP (op_sub, octave_sparse_matrix, octave_diag_matrix, sub_sm_dm);
   INSTALL_BINOP (op_div, octave_sparse_matrix, octave_diag_matrix, div_sm_dm);
 }
diff --git a/test/ChangeLog b/test/ChangeLog
--- a/test/ChangeLog
+++ b/test/ChangeLog
@@ -1,3 +1,7 @@
+2009-03-09  Jason Riedy  <address@hidden>
+
+       * test_diag_perm.m: Add tests for diag + sparse.
+
 2009-03-08  Jason Riedy  <address@hidden>
 
        * test_diag_perm.m: Add tests for inverse scaling and sparse structure.
diff --git a/test/test_diag_perm.m b/test/test_diag_perm.m
--- a/test/test_diag_perm.m
+++ b/test/test_diag_perm.m
@@ -182,3 +182,28 @@
 %! scalefact = rand (1, n-2) + I () * rand(1, n-2);
 %! Dc = diag (scalefact, n-2, n);
 %! assert (full (A / Dc), full(A) / Dc)
+
+## adding/subtracting sparse and diagonal keeps a sparse result type
+%!test
+%! n = 9;
+%! A = sprand (n, n, .5);
+%! D = 2 * eye (n);
+%! assert (typeinfo (A + D), "sparse matrix")
+%! assert (typeinfo (A - D), "sparse matrix")
+%! D = D * I () + D;
+%! assert (typeinfo (A - D), "sparse complex matrix")
+%! A = A * I () + A;
+%! assert (typeinfo (D - A), "sparse complex matrix")
+
+## adding/subtracting sparse and diagonal matches the dense result
+%!test
+%! n = 9;
+%! A = sprand (n, n, .5);
+%! D = 2 * eye (n);
+%! assert (full (A + D), full (A) + D)
+%! assert (full (A - D), full (A) - D)
+%! D = D * I () + D;
+%! assert (full (D + A), D + full (A))
+%! A = A * I () + A;
+%! A(6, 4) = nan ();
+%! assert (full (D - A), D - full (A))


reply via email to

[Prev in Thread] Current Thread [Next in Thread]