diff --git a/Makefile.am b/Makefile.am index 271f431ff..fe3fc232f 100644 --- a/Makefile.am +++ b/Makefile.am @@ -249,7 +249,6 @@ core_SOURCES = \ src/core/vector.h \ src/core/recarray.h \ src/core/matrix.h \ - src/core/matrix.imp \ src/core/integer.cc \ src/core/integer.h \ src/core/rational.cc \ diff --git a/src/core/matrix.cc b/src/core/matrix.cc index 6b5fd4496..6f25b8fd3 100644 --- a/src/core/matrix.cc +++ b/src/core/matrix.cc @@ -20,8 +20,7 @@ // Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. // -#include "core.h" -#include "matrix.imp" +#include "matrix.h" namespace Gambit { @@ -30,9 +29,4 @@ template class Matrix; template class Matrix; template class Matrix; -template Vector operator*(const Vector &, const Matrix &); -template Vector operator*(const Vector &, const Matrix &); -template Vector operator*(const Vector &, const Matrix &); -template Vector operator*(const Vector &, const Matrix &); - } // end namespace Gambit diff --git a/src/core/matrix.h b/src/core/matrix.h index 20d66f482..c07af4f80 100644 --- a/src/core/matrix.h +++ b/src/core/matrix.h @@ -23,8 +23,11 @@ #ifndef GAMBIT_CORE_MATRIX_H #define GAMBIT_CORE_MATRIX_H +#include + #include "recarray.h" #include "vector.h" +#include "rational.h" namespace Gambit { @@ -34,81 +37,565 @@ class SingularMatrixException final : public std::runtime_error { ~SingularMatrixException() noexcept override = default; }; -template Vector operator*(const Vector &, const Matrix &); - -template class Matrix : public RectArray { - friend Vector operator* <>(const Vector &, const Matrix &); +/// @brief Dense rectangular matrix with arbitrary integer index ranges. +/// +/// @tparam T Scalar element type. Must support arithmetic operations, +/// comparison with zero, and abs(T). +/// +/// @note Inverse and Determinant use legacy Gaussian elimination algorithms. +/// They are known not to be numerically optimal for near-singular matrices. +/// Current behaviour is temporarily preserved for historical compatibility. +template class Matrix { + RectArray m_data; public: /// @name Lifecycle - //@{ - Matrix(); - Matrix(unsigned int rows, unsigned int cols); - Matrix(unsigned int rows, unsigned int cols, int minrows); - Matrix(int rl, int rh, int cl, int ch); - Matrix(const Matrix &); - ~Matrix() override; - - Matrix &operator=(const Matrix &); - Matrix &operator=(const T &); - //@} - - /// @name Extracting rows and columns - //@{ - bool IsSquare() const + /// Constructors, assignment, and destruction + ///@{ + Matrix() = default; + Matrix(unsigned int rows, unsigned int cols) : m_data(rows, cols) {} + Matrix(unsigned int rows, unsigned int cols, int minrows) + : m_data(minrows, minrows + rows - 1, 1, cols) { - return this->MinRow() == this->MinCol() && this->MaxRow() == this->MaxCol(); } - Vector Row(int) const; - Vector Column(int) const; - //@} + Matrix(int rl, int rh, int cl, int ch) : m_data(rl, rh, cl, ch) {} + Matrix(const Matrix &) = default; + Matrix(Matrix &&) noexcept = default; + ~Matrix() = default; + + Matrix &operator=(const Matrix &) = default; + Matrix &operator=(Matrix &&) noexcept = default; + Matrix &operator=(const T &); + ///@} + + /// @brief Access matrix element at (row, column) with bounds checking + /// + /// @param r Row index + /// @param c Column index + /// @throws std::out_of_range on invalid index. + T &operator()(int r, int c) { return m_data(r, c); } + /// @copydoc operator(int, int) + const T &operator()(int r, int c) const { return m_data(r, c); } + + /// @brief Lowest valid row index + int MinRow() const { return m_data.MinRow(); } + /// @brief Highest valid row index + int MaxRow() const { return m_data.MaxRow(); } + /// @brief Lowest valid column index + int MinCol() const { return m_data.MinCol(); } + /// @brief Highest valid column index + int MaxCol() const { return m_data.MaxCol(); } + /// @brief Number of rows in the matrix + int NumRows() const { return m_data.NumRows(); } + /// @brief Number of columns in the matrix + int NumColumns() const { return m_data.NumColumns(); } + + /// @brief Check if the matrix is a square matrix (num rows == num columns) + bool IsSquare() const { return MinRow() == MinCol() && MaxRow() == MaxCol(); } + + /// @name Row and column helpers + ///@{ + void SwitchRows(int i, int j) { m_data.SwitchRows(i, j); } + template void GetColumn(int j, V &) const; + template void SetColumn(int j, const V &); + template void GetRow(int row, V &) const; + template void SetRow(int row, const V &); + + /// @brief Test whether a vector conforms to the matrix row shape + template bool ConformsToRow(const V &v) const { return m_data.ConformsToRow(v); } + /// @brief Test whether a vector conforms to the matrix column shape + template bool ConformsToColumn(const V &v) const { return m_data.ConformsToColumn(v); } + /// @brief Test whether another matrix conforms to the shape of this matrix + bool ConformsTo(const Matrix &M) const { return m_data.ConformsTo(M.m_data); } + ///@} /// @name Comparison operators - //@{ - bool operator==(const Matrix &) const; - bool operator!=(const Matrix &) const; + /// Element-wise comparisons + ///@{ + bool operator==(const Matrix &) const; + bool operator!=(const Matrix &M) const { return !(*this == M); } bool operator==(const T &) const; - bool operator!=(const T &) const; - //@} - - /// @name Additive operators - //@{ - Matrix operator+(const Matrix &) const; - Matrix operator-(const Matrix &) const; - Matrix &operator+=(const Matrix &); - Matrix &operator-=(const Matrix &); - - Matrix operator-(); - //@} - - /// @name Multiplicative operators - //@{ - /// "in-place" column multiply - void CMultiply(const Vector &, Vector &) const; - /// "in-place" row (transposed) multiply - void RMultiply(const Vector &, Vector &) const; - Matrix operator*(const Matrix &) const; - Vector operator*(const Vector &) const; - Matrix operator*(const T &) const; - Matrix &operator*=(const T &); - - Matrix operator/(const T &) const; - Matrix &operator/=(const T &); - //@ + bool operator!=(const T &c) const { return !(*this == c); } + ///@} - /// @name Other operations - //@{ - Matrix Transpose() const; - /// Set matrix to identity matrix - void MakeIdent(); - void Pivot(int, int); - //@} + /// @name Arithmetic operators + /// Element-wise and scalar arithmetic + /// + /// All matrix-matrix operations require identical row and column bounds and + /// operate component-by-component. Scalar operations apply uniformly to all + /// elements. + ///@{ + Matrix operator+(const Matrix &M) const + { + Matrix tmp(*this); + tmp += M; + return tmp; + } + Matrix operator-(const Matrix &M) const + { + Matrix tmp(*this); + tmp -= M; + return tmp; + } + Matrix &operator+=(const Matrix &); + Matrix &operator-=(const Matrix &); + + Matrix operator-() const; + Matrix operator*(const T &c) const + { + Matrix tmp(*this); + tmp *= c; + return tmp; + } + Matrix &operator*=(const T &); + + Matrix operator/(const T &c) const + { + Matrix tmp(*this); + tmp /= c; + return tmp; + } + Matrix &operator/=(const T &); + ///@} + + /// @name Linear algebra operations + /// Some primitives for doing linear algebra. + ///@{ + + /// @brief Multiply a matrix by a column vector + /// + /// Computes p_output = (*this) * p_input, where @p p_input is interpreted + /// as a column vector and @p p_output receives the resulting column vector. + /// + /// @param p_input Input column vector + /// @param p_output Output column vector + /// @throws DimensionException if dimensions are incompatible + void CMultiply(const Vector &p_input, Vector &p_output) const; + + /// @brief Multiply a row vector by the matrix + /// + /// Computes p_output = p_input * (*this), where @p p_input is interpreted + /// as a row vector and @p p_output receives the resulting row vector. + /// + /// @param p_input Input row vector + /// @param p_output Output row vector + /// @throws DimensionException if dimensions are incompatible + void RMultiply(const Vector &p_input, Vector &p_output) const; + + /// @brief Multiply matrix by a column vector. + /// + /// Equivalent to CMultiply(v, result) + /// + /// @param v Input column vector + /// @return Resulting column vector + /// @throws DimensionException if dimensions are incompatible + Vector operator*(const Vector &v) const; + + /// @brief Matrix-matrix multiplication + /// + /// Computes the product (*this) * M; the number of columns of this matrix + /// must equal the number of rows in @param M . + /// + /// @param M Matrix to multiply this with + /// @return Resulting matrix + /// @throws DimensionException if dimensions are incompatible + Matrix operator*(const Matrix &M) const; + + ///@ + + /// @name Other operations + ///@{ + Matrix Transpose() const; Matrix Inverse() const; T Determinant() const; + ///@} }; -template Vector operator*(const Vector &, const Matrix &); +template Matrix &Matrix::operator=(const T &c) +{ + std::fill(m_data.elements_begin(), m_data.elements_end(), c); + return *this; +} + +// ---------------------------------------------------------------------------- +// Implementation of element-wise operations +// ---------------------------------------------------------------------------- + +template bool Matrix::operator==(const Matrix &M) const +{ + if (!this->ConformsTo(M)) { + throw DimensionException(); + } + return std::equal(m_data.elements_begin(), m_data.elements_end(), M.m_data.elements_begin()); +} + +template bool Matrix::operator==(const T &c) const +{ + return std::all_of(m_data.elements_begin(), m_data.elements_end(), + [&c](const auto &v) { return v == c; }); +} + +template Matrix &Matrix::operator+=(const Matrix &M) +{ + if (!this->ConformsTo(M)) { + throw DimensionException(); + } + std::transform(m_data.elements_begin(), m_data.elements_end(), M.m_data.elements_begin(), + m_data.elements_begin(), std::plus<>()); + return *this; +} + +template Matrix &Matrix::operator-=(const Matrix &M) +{ + if (!this->ConformsTo(M)) { + throw DimensionException(); + } + std::transform(m_data.elements_begin(), m_data.elements_end(), M.m_data.elements_begin(), + m_data.elements_begin(), std::minus<>()); + return *this; +} + +template Matrix Matrix::operator-() const +{ + Matrix tmp(*this); + std::transform(tmp.m_data.elements_begin(), tmp.m_data.elements_end(), + tmp.m_data.elements_begin(), std::negate<>()); + return tmp; +} + +template Matrix &Matrix::operator*=(const T &c) +{ + std::transform(m_data.elements_begin(), m_data.elements_end(), m_data.elements_begin(), + [&c](const T &v) { return v * c; }); + return *this; +} + +template Matrix &Matrix::operator/=(const T &c) +{ + if (c == T{0}) { + throw ZeroDivideException(); + } + std::transform(m_data.elements_begin(), m_data.elements_end(), m_data.elements_begin(), + [&c](const T &v) { return v / c; }); + return *this; +} + +// ---------------------------------------------------------------------------- +// Implementation of row/column access +// ---------------------------------------------------------------------------- + +template template void Matrix::GetColumn(int col, V &v) const +{ + if (col < MinCol() || col > MaxCol()) { + throw std::out_of_range("Index out of range in Matrix::GetColumn"); + } + if (!ConformsToColumn(v)) { + throw DimensionException(); + } + for (int i = MinRow(); i <= MaxRow(); ++i) { + v[i] = (*this)(i, col); + } +} + +template template void Matrix::SetColumn(int col, const V &v) +{ + if (col < MinCol() || col > MaxCol()) { + throw std::out_of_range("Index out of range in Matrix::SetColumn"); + } + if (!ConformsToColumn(v)) { + throw DimensionException(); + } + for (int i = MinRow(); i <= MaxRow(); ++i) { + (*this)(i, col) = v[i]; + } +} + +template template void Matrix::GetRow(int row, V &v) const +{ + if (row < MinRow() || row > MaxRow()) { + throw std::out_of_range("Index out of range in Matrix::GetRow"); + } + if (!ConformsToRow(v)) { + throw DimensionException(); + } + for (int j = MinCol(); j <= MaxCol(); ++j) { + v[j] = (*this)(row, j); + } +} + +template template void Matrix::SetRow(int row, const V &v) +{ + if (row < MinRow() || row > MaxRow()) { + throw std::out_of_range("Index out of range in Matrix::SetRow"); + } + if (!ConformsToRow(v)) { + throw DimensionException(); + } + for (int j = MinCol(); j <= MaxCol(); ++j) { + (*this)(row, j) = v[j]; + } +} + +// ---------------------------------------------------------------------------- +// Implementation of linear algebra concepts +// ---------------------------------------------------------------------------- + +template void Matrix::CMultiply(const Vector &p_input, Vector &p_output) const +{ + if (!this->ConformsToRow(p_input) || !this->ConformsToColumn(p_output)) { + throw DimensionException(); + } + for (int i = MinRow(); i <= MaxRow(); ++i) { + auto row = m_data.GetRowView(i); + p_output[i] = std::inner_product(row.begin(), row.end(), p_input.begin(), T{0}); + } +} + +template void Matrix::RMultiply(const Vector &p_input, Vector &p_output) const +{ + if (!this->ConformsToColumn(p_input) || !this->ConformsToRow(p_output)) { + throw DimensionException(); + } + + p_output = T{0}; + for (int i = MinRow(); i <= MaxRow(); ++i) { + auto row = m_data.GetRowView(i); + const T k = p_input[i]; + auto dst = p_output.begin(); + for (auto it = row.begin(); it != row.end(); ++it, ++dst) { + *dst += (*it) * k; + } + } +} + +template Vector Matrix::operator*(const Vector &v) const +{ + if (!this->ConformsToRow(v)) { + throw DimensionException(); + } + Vector tmp(MinRow(), MaxRow()); + CMultiply(v, tmp); + return tmp; +} + +template Matrix Matrix::operator*(const Matrix &M) const +{ + if (MinCol() != M.MinRow() || MaxCol() != M.MaxRow()) { + throw DimensionException(); + } + Matrix tmp(MinRow(), MaxRow(), M.MinCol(), M.MaxCol()); + for (int i = MinRow(); i <= MaxRow(); ++i) { + auto row = m_data.GetRowView(i); + for (int j = M.MinCol(); j <= M.MaxCol(); ++j) { + auto col = M.m_data.GetColumnView(j); + tmp(i, j) = std::inner_product(row.begin(), row.end(), col.begin(), T{0}); + } + } + return tmp; +} + +template Matrix Matrix::Transpose() const +{ + Matrix tmp(MinCol(), MaxCol(), MinRow(), MaxRow()); + + for (int i = MinRow(); i <= MaxRow(); ++i) { + auto src_row = m_data.GetRowView(i); + auto dst_col = tmp.m_data.GetColumnView(i); + + auto src = src_row.begin(); + auto dst = dst_col.begin(); + + for (; src != src_row.end(); ++src, ++dst) { + *dst = *src; + } + } + + return tmp; +} + +// ---------------------------------------------------------------------------- +// Implementation of additional operations +// ---------------------------------------------------------------------------- + +template Matrix Matrix::Inverse() const +{ + if (!IsSquare()) { + throw DimensionException(); + } + const int rmin = MinRow(); + const int rmax = MaxRow(); + const int cmin = MinCol(); + const int cmax = MaxCol(); + using Gambit::abs; + + Matrix copy(*this); + Matrix inv(rmin, rmax, cmin, cmax); + + inv = T{0}; + + // initialize inverse matrix and prescale row vectors + for (int i = rmin; i <= rmax; ++i) { + auto copy_row = copy.m_data.GetRowView(i); + auto inv_row = inv.m_data.GetRowView(i); + + T max = maximize_function(copy_row, [](const T &v) { return abs(v); }); + if (max == T{0}) { + throw SingularMatrixException(); + } + const T scale = T{1} / max; + for (auto &v : copy_row) { + v *= scale; + } + inv_row[i] = scale; + } + + for (int i = cmin; i <= cmax; ++i) { + // find pivot row + auto col_i = copy.m_data.GetColumnView(i); + T max = abs(col_i[i]); + int row = i; + for (int j = i + 1; j <= rmax; ++j) { + const T v = abs(col_i[j]); + if (v > max) { + max = v; + row = j; + } + } + + if (max <= T{0}) { + throw SingularMatrixException(); + } + + copy.SwitchRows(i, row); + inv.SwitchRows(i, row); + // scale pivot row + T factor = T{1} / copy(i, i); + auto copy_row = copy.m_data.GetRowView(i); + auto inv_row = inv.m_data.GetRowView(i); + auto copy_it = copy_row.begin(); + auto inv_it = inv_row.begin(); + for (; copy_it != copy_row.end(); ++copy_it, ++inv_it) { + *copy_it *= factor; + *inv_it *= factor; + } + + // reduce other rows + auto pivot_copy_row = copy.m_data.GetRowView(i); + auto pivot_inv_row = inv.m_data.GetRowView(i); + + for (int j = rmin; j <= rmax; ++j) { + if (j == i) { + continue; + } + auto row_copy = copy.m_data.GetRowView(j); + auto row_inv = inv.m_data.GetRowView(j); + const T mult = row_copy[i]; + auto pivot_copy_it = pivot_copy_row.begin(); + auto pivot_inv_it = pivot_inv_row.begin(); + auto row_copy_it = row_copy.begin(); + auto row_inv_it = row_inv.begin(); + + for (; pivot_copy_it != pivot_copy_row.end(); + ++pivot_copy_it, ++pivot_inv_it, ++row_copy_it, ++row_inv_it) { + *row_copy_it -= (*pivot_copy_it) * mult; + *row_inv_it -= (*pivot_inv_it) * mult; + } + } + } + + return inv; +} + +template T Matrix::Determinant() const +{ + if (!IsSquare()) { + throw DimensionException(); + } + const int rmin = MinRow(); + const int rmax = MaxRow(); + using Gambit::abs; + + Matrix M(*this); + + for (int row = rmin; row <= rmax; ++row) { + // Experience (as of 3/22/99) suggests that, in the interest of + // numerical stability, it might be best to do Gaussian + // elimination with respect to the row (of those feasible) + // whose entry has the largest absolute value. + int swap_row = row; + T max = abs(M(row, row)); + for (int i = row + 1; i <= rmax; ++i) { + const T v = abs(M(i, row)); + if (v > max) { + max = v; + swap_row = i; + } + } + + if (swap_row != row) { + M.SwitchRows(row, swap_row); + auto pivot_row = M.m_data.GetRowView(row); + for (auto &v : pivot_row) { + v = -v; + } + } + + if (M(row, row) == T{0}) { + return T{0}; + } + // now do row operations to clear the row'th column + // below the diagonal + auto pivot_row = M.m_data.GetRowView(row); + for (int row1 = row + 1; row1 <= rmax; ++row1) { + auto elim_row = M.m_data.GetRowView(row1); + const T factor = -elim_row[row] / pivot_row[row]; + auto pivot_it = pivot_row.begin(); + auto elim_it = elim_row.begin(); + for (; pivot_it != pivot_row.end(); ++pivot_it, ++elim_it) { + *elim_it += (*pivot_it) * factor; + } + } + } + + // finally we multiply the diagonal elements + T det = T{1}; + for (int row = rmin; row <= rmax; ++row) { + det *= M(row, row); + } + return det; +} + +// ---------------------------------------------------------------------------- +// Implementation of operators +// ---------------------------------------------------------------------------- + +/// @brief Multiple a row vector by a matrix +/// +/// Computes v * M, where @param v is interpreted as a row vector +/// +/// @param v The row vector +/// @param M The matrix to multiply with +/// @throws DimensionException if dimensions are incompatible +/// @sa Matrix::RMultiply +template Vector operator*(const Vector &v, const Matrix &M) +{ + if (!M.ConformsToColumn(v)) { + throw DimensionException(); + } + Vector tmp(M.MinCol(), M.MaxCol()); + M.RMultiply(v, tmp); + return tmp; +} + +// ---------------------------------------------------------------------------- +// Explicit instantiations +// ---------------------------------------------------------------------------- + +extern template class Matrix; +extern template class Matrix; +extern template class Matrix; +extern template class Matrix; } // end namespace Gambit diff --git a/src/core/matrix.imp b/src/core/matrix.imp index fdddbccf1..e69de29bb 100644 --- a/src/core/matrix.imp +++ b/src/core/matrix.imp @@ -1,585 +0,0 @@ -// -// This file is part of Gambit -// Copyright (c) 1994-2026, The Gambit Project (https://www.gambit-project.org) -// -// FILE: src/core/matrix.imp -// Implementation of matrix method functions -// -// This program is free software; you can redistribute it and/or modify -// it under the terms of the GNU General Public License as published by -// the Free Software Foundation; either version 2 of the License, or -// (at your option) any later version. -// -// This program is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU General Public License for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program; if not, write to the Free Software -// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. -// - -#include "matrix.h" - -namespace Gambit { - -//------------------------------------------------------------------------- -// Matrix: Constructors, destructors, constructive operators -//------------------------------------------------------------------------- - -template Matrix::Matrix() = default; - -template -Matrix::Matrix(unsigned int rows, unsigned int cols) : RectArray(rows, cols) -{ -} - -template -Matrix::Matrix(unsigned int rows, unsigned int cols, int minrows) - : RectArray(minrows, minrows + rows - 1, 1, cols) -{ -} - -template Matrix::Matrix(int rl, int rh, int cl, int ch) : RectArray(rl, rh, cl, ch) -{ -} - -template Matrix::Matrix(const Matrix &M) : RectArray(M) {} - -template Matrix::~Matrix() = default; - -template Matrix &Matrix::operator=(const Matrix &M) -{ - RectArray::operator=(M); - return *this; -} - -template Matrix &Matrix::operator=(const T &c) -{ - for (int i = this->minrow; i <= this->maxrow; i++) { - for (int j = this->mincol; j <= this->maxcol; j++) { - (*this)(i, j) = c; - } - } - return *this; -} - -template Matrix Matrix::operator-() -{ - Matrix tmp(this->minrow, this->maxrow, this->mincol, this->maxcol); - for (int i = this->minrow; i <= this->maxrow; i++) { - for (int j = this->mincol; j <= this->maxcol; j++) { - tmp(i, j) = -(*this)(i, j); - } - } - return tmp; -} - -//------------------------------------------------------------------------- -// Matrix: Additive operators -//------------------------------------------------------------------------- - -template Matrix Matrix::operator+(const Matrix &M) const -{ - if (!this->CheckBounds(M)) { - throw DimensionException(); - } - - const Matrix tmp(this->minrow, this->maxrow, this->mincol, this->maxcol); - for (int i = this->minrow; i <= this->maxrow; i++) { - const T *src1 = this->data[i] + this->mincol; - const T *src2 = M.data[i] + this->mincol; - // NOLINTBEGIN(misc-const-correctness) - T *dst = tmp.data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) = *(src1++) + *(src2++); - } - // assert((dst - 1) == tmp.data[i] + this->maxcol ); - } - return tmp; -} - -template Matrix Matrix::operator-(const Matrix &M) const -{ - if (!this->CheckBounds(M)) { - throw DimensionException(); - } - - const Matrix tmp(this->minrow, this->maxrow, this->mincol, this->maxcol); - for (int i = this->minrow; i <= this->maxrow; i++) { - const T *src1 = this->data[i] + this->mincol; - const T *src2 = M.data[i] + this->mincol; - // NOLINTBEGIN(misc-const-correctness) - T *dst = tmp.data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) = *(src1++) - *(src2++); - } - // assert((dst - 1) == tmp.data[i] + this->maxcol); - } - return tmp; -} - -template Matrix &Matrix::operator+=(const Matrix &M) -{ - if (!this->CheckBounds(M)) { - throw DimensionException(); - } - - for (int i = this->minrow; i <= this->maxrow; i++) { - const T *src = M.data[i] + this->mincol; - // NOLINTBEGIN(misc-const-correctness) - T *dst = this->data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) += *(src++); - } - } - return (*this); -} - -template Matrix &Matrix::operator-=(const Matrix &M) -{ - if (!this->CheckBounds(M)) { - throw DimensionException(); - } - - for (int i = this->minrow; i <= this->maxrow; i++) { - const T *src = M.data[i] + this->mincol; - // NOLINTBEGIN(misc-const-correctness) - T *dst = this->data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) -= *(src++); - } - // assert((dst - 1) == this->data[i] + this->maxcol); - } - return (*this); -} - -//------------------------------------------------------------------------- -// Matrix: Multiplicative operators -//------------------------------------------------------------------------- - -template void Matrix::CMultiply(const Vector &in, Vector &out) const -{ - if (!this->CheckRow(in) || !this->CheckColumn(out)) { - throw DimensionException(); - } - - for (int i = this->minrow; i <= this->maxrow; i++) { - T sum = (T)0; - - const T *src1 = this->data[i] + this->mincol; - auto src2 = in.begin(); - int j = this->maxcol - this->mincol + 1; - while (j--) { - sum += *(src1++) * *(src2++); - } - out[i] = sum; - } -} - -template Matrix Matrix::operator*(const Matrix &M) const -{ - if (this->mincol != M.minrow || this->maxcol != M.maxrow) { - throw DimensionException(); - } - - Matrix tmp(this->minrow, this->maxrow, M.mincol, M.maxcol); - Vector column(M.minrow, M.maxrow); - Vector result(this->minrow, this->maxrow); - for (int j = M.mincol; j <= M.maxcol; j++) { - M.GetColumn(j, column); - CMultiply(column, result); - tmp.SetColumn(j, result); - } - return tmp; -} - -template Vector Matrix::operator*(const Vector &v) const -{ - if (!this->CheckRow(v)) { - throw DimensionException(); - } - - Vector tmp(this->minrow, this->maxrow); - CMultiply(v, tmp); - return tmp; -} - -template void Matrix::RMultiply(const Vector &in, Vector &out) const -{ - if (!this->CheckColumn(in) || !this->CheckRow(out)) { - throw DimensionException(); - } - - out = (T)0; - for (int i = this->minrow; i <= this->maxrow; i++) { - T k = in[i]; - - const T *src = this->data[i] + this->mincol; - auto dst = out.begin(); - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) += *(src++) * k; - } - // assert(src - 1 == this->data[i] + this->maxcol); - } -} - -// transposed (row) vector*matrix multiplication operator -// a friend function of Matrix -template Vector operator*(const Vector &v, const Matrix &M) -{ - if (!M.CheckColumn(v)) { - throw DimensionException(); - } - Vector tmp(M.MinCol(), M.MaxCol()); - M.RMultiply(v, tmp); - return tmp; -} - -template Matrix Matrix::operator*(const T &s) const -{ - const Matrix tmp(this->minrow, this->maxrow, this->mincol, this->maxcol); - for (int i = this->minrow; i <= this->maxrow; i++) { - const T *src = this->data[i] + this->mincol; - // NOLINTBEGIN(misc-const-correctness) - T *dst = tmp.data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) = *(src++) * s; - } - // assert((src - 1) == this->data[i] + this->maxcol); - } - return tmp; -} - -template Matrix &Matrix::operator*=(const T &s) -{ - for (int i = this->minrow; i <= this->maxrow; i++) { - // NOLINTBEGIN(misc-const-correctness) - T *dst = this->data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) *= s; - } - } - return (*this); -} - -template Matrix Matrix::operator/(const T &s) const -{ - if (s == (T)0) { - throw ZeroDivideException(); - } - - const Matrix tmp(this->minrow, this->maxrow, this->mincol, this->maxcol); - for (int i = this->minrow; i <= this->maxrow; i++) { - const T *src = this->data[i] + this->mincol; - // NOLINTBEGIN(misc-const-correctness) - T *dst = tmp.data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) = *(src++) / s; - } - } - return tmp; -} - -template Matrix &Matrix::operator/=(const T &s) -{ - if (s == (T)0) { - throw ZeroDivideException(); - } - - for (int i = this->minrow; i <= this->maxrow; i++) { - // NOLINTBEGIN(misc-const-correctness) - T *dst = this->data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) /= s; - } - } - return (*this); -} - -//------------------------------------------------------------------------- -// Matrix: Transpose -//------------------------------------------------------------------------- - -template Matrix Matrix::Transpose() const -{ - Matrix tmp(this->mincol, this->maxcol, this->minrow, this->maxrow); - - for (int i = this->minrow; i <= this->maxrow; i++) { - for (int j = this->mincol; j <= this->maxcol; j++) { - tmp(j, i) = (*this)(i, j); - } - } - - return tmp; -} - -//------------------------------------------------------------------------- -// Matrix: Comparison operators -//------------------------------------------------------------------------- - -template bool Matrix::operator==(const Matrix &M) const -{ - if (!this->CheckBounds(M)) { - throw DimensionException(); - } - - for (int i = this->minrow; i <= this->maxrow; i++) { - // inner loop - const T *src1 = M.data[i] + this->mincol; - const T *src2 = this->data[i] + this->mincol; - int j = this->maxcol - this->mincol + 1; - while (j--) { - if (*(src1++) != *(src2++)) { - return false; - } - } - // assert(src1 - 1 == M.data[i] + this->maxcol); - } - return true; -} - -template bool Matrix::operator!=(const Matrix &M) const { return !(*this == M); } - -template bool Matrix::operator==(const T &s) const -{ - for (int i = this->minrow; i <= this->maxrow; i++) { - const T *src = this->data[i] + this->mincol; - int j = this->maxcol - this->mincol + 1; - while (j--) { - if (*(src++) != s) { - return false; - } - } - // assert(src - 1 == this->data[i] + this->maxcol); - } - return true; -} - -template bool Matrix::operator!=(const T &s) const { return !(*this == s); } - -// Information - -template Vector Matrix::Row(int i) const -{ - Vector answer(this->mincol, this->maxcol); - for (int j = this->mincol; j <= this->maxcol; j++) { - answer[j] = (*this)(i, j); - } - return answer; -} - -template Vector Matrix::Column(int j) const -{ - Vector answer(this->minrow, this->maxrow); - for (int i = this->minrow; i <= this->maxrow; i++) { - answer[i] = (*this)(i, j); - } - return answer; -} - -// more complex functions - -template void Matrix::MakeIdent() -{ - if (!IsSquare()) { - throw DimensionException(); - } - for (int i = this->minrow; i <= this->maxrow; i++) { - for (int j = this->mincol; j <= this->maxcol; j++) { - if (i == j) { - (*this)(i, j) = (T)1; - } - else { - (*this)(i, j) = (T)0; - } - } - } -} - -template void Matrix::Pivot(int row, int col) -{ - if (!this->CheckRow(row) || !this->CheckColumn(col)) { - throw std::out_of_range("Index out of range in Matrix::Pivot"); - } - if (this->data[row][col] == (T)0) { - throw ZeroDivideException(); - } - - T mult = (T)1 / this->data[row][col]; - for (int j = this->mincol; j <= this->maxcol; j++) { - this->data[row][j] *= mult; - } - for (int i = this->minrow; i <= this->maxrow; i++) { - if (i != row) { - mult = this->data[i][col]; - - // inner loop - const T *src = this->data[row] + this->mincol; - // NOLINTBEGIN(misc-const-correctness) - T *dst = this->data[i] + this->mincol; - // NOLINTEND(misc-const-correctness) - int j = this->maxcol - this->mincol + 1; - while (j--) { - *(dst++) -= *(src++) * mult; - } - // assert( dst-1 == this->data[i] + this->maxcol ); // debug - // end inner loop - } - } -} - -template Matrix Matrix::Inverse() const -{ - if (!IsSquare()) { - throw DimensionException(); - } - - Matrix copy(*this); - Matrix inv(this->MaxRow(), this->MaxRow()); - - // initialize inverse matrix and prescale row vectors - for (int i = this->MinRow(); i <= this->MaxRow(); i++) { - T max = (T)0; - for (int j = this->MinCol(); j <= this->MaxCol(); j++) { - T abs = copy(i, j); - if (abs < (T)0) { - abs = -abs; - } - if (abs > max) { - max = abs; - } - } - - if (max == (T)0) { - throw SingularMatrixException(); - } - - T scale = (T)1 / max; - for (int j = this->MinCol(); j <= this->MaxCol(); j++) { - copy(i, j) *= scale; - if (i == j) { - inv(i, j) = scale; - } - else { - inv(i, j) = (T)0; - } - } - } - - for (int i = this->MinCol(); i <= this->MaxCol(); i++) { - // find pivot row - T max = copy(i, i); - if (max < (T)0) { - max = -max; - } - int row = i; - for (int j = i + 1; j <= this->MaxRow(); j++) { - T abs = copy(j, i); - if (abs < (T)0) { - abs = -abs; - } - if (abs > max) { - max = abs; - row = j; - } - } - - if (max <= (T)0) { - throw SingularMatrixException(); - } - - copy.SwitchRows(i, row); - inv.SwitchRows(i, row); - // scale pivot row - T factor = (T)1 / copy(i, i); - for (int k = this->MinCol(); k <= this->MaxCol(); k++) { - copy(i, k) *= factor; - inv(i, k) *= factor; - } - - // reduce other rows - for (int j = this->MinRow(); j <= this->MaxRow(); j++) { - if (j != i) { - T mult = copy(j, i); - for (int k = this->MinCol(); k <= this->MaxCol(); k++) { - copy(j, k) -= copy(i, k) * mult; - inv(j, k) -= inv(i, k) * mult; - } - } - } - } - - return inv; -} - -template T Matrix::Determinant() const -{ - if (!IsSquare()) { - throw DimensionException(); - } - - T factor = (T)1; - Matrix M(*this); - - for (int row = this->MinRow(); row <= this->MaxRow(); row++) { - - // Experience (as of 3/22/99) suggests that, in the interest of - // numerical stability, it might be best to do Gaussian - // elimination with respect to the row (of those feasible) - // whose entry has the largest absolute value. - int swap_row = row; - for (int i = row + 1; i <= this->MaxRow(); i++) { - if (abs(M.data[i][row]) > abs(M.data[swap_row][row])) { - swap_row = i; - } - } - - if (swap_row != row) { - M.SwitchRows(row, swap_row); - for (int j = this->MinCol(); j <= this->MaxCol(); j++) { - M.data[row][j] *= (T)-1; - } - } - - if (M.data[row][row] == (T)0) { - return (T)0; - } - - // now do row operations to clear the row'th column - // below the diagonal - for (int row1 = row + 1; row1 <= this->MaxRow(); row1++) { - factor = -M.data[row1][row] / M.data[row][row]; - for (int i = this->MinCol(); i <= this->MaxCol(); i++) { - M.data[row1][i] += M.data[row][i] * factor; - } - } - } - - // finally we multiply the diagonal elements - T det = (T)1; - for (int row = this->MinRow(); row <= this->MaxRow(); row++) { - det *= M.data[row][row]; - } - return det; -} - -} // end namespace Gambit diff --git a/src/core/recarray.h b/src/core/recarray.h index bddb1b9ca..9d09c49ca 100644 --- a/src/core/recarray.h +++ b/src/core/recarray.h @@ -23,6 +23,7 @@ #ifndef GAMBIT_CORE_RECARRAY_H #define GAMBIT_CORE_RECARRAY_H +#include #include "util.h" namespace Gambit { @@ -30,54 +31,162 @@ namespace Gambit { /// This class implements a rectangular (two-dimensional) array template class RectArray { protected: - int minrow, maxrow, mincol, maxcol; - T **data; + int m_minrow, m_maxrow, m_mincol, m_maxcol; + std::vector m_storage; + + size_t row_stride() const { return static_cast(m_maxcol - m_mincol + 1); } + + size_t index(int r, int c) const + { + return static_cast(r - m_minrow) * row_stride() + static_cast(c - m_mincol); + } + +public: + class RowView { + RectArray *m_array; + int m_row; + + public: + class iterator { + RowView *m_view; + int m_col; + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = int; + using pointer = T *; + using reference = T &; + + iterator(RowView *p_view, int p_col) : m_view(p_view), m_col(p_col) {} + + reference operator*() const { return (*m_view)[m_col]; } + + iterator &operator++() + { + ++m_col; + return *this; + } + + bool operator==(const iterator &p_other) const { return m_col == p_other.m_col; } + bool operator!=(const iterator &p_other) const { return !(*this == p_other); } + }; + + RowView(RectArray &p_array, int p_row) : m_array(&p_array), m_row(p_row) + { + if (!m_array->CheckRow(m_row)) { + throw std::out_of_range("RowView"); + } + } + T &operator[](int c) { return (*m_array)(m_row, c); } + const T &operator[](int c) const { return (*m_array)(m_row, c); } + int MinIndex() const { return m_array->MinCol(); } + int MaxIndex() const { return m_array->MaxCol(); } + iterator begin() { return iterator(this, MinIndex()); } + iterator begin() const { return iterator(const_cast(this), MinIndex()); } + iterator end() { return iterator(this, MaxIndex() + 1); } + iterator end() const { return iterator(const_cast(this), MaxIndex() + 1); } + }; + + class ColumnView { + RectArray *m_array; + int m_col; + + public: + class iterator { + ColumnView *m_view; + int m_row; + + public: + using iterator_category = std::forward_iterator_tag; + using value_type = T; + using difference_type = int; + using pointer = T *; + using reference = T &; + + iterator(ColumnView *p_view, int p_row) : m_view(p_view), m_row(p_row) {} + + reference operator*() const { return (*m_view)[m_row]; } + + iterator &operator++() + { + ++m_row; + return *this; + } + + bool operator==(const iterator &p_other) const { return m_row == p_other.m_row; } + bool operator!=(const iterator &p_other) const { return !(*this == p_other); } + }; + + ColumnView(RectArray &p_array, int p_col) : m_array(&p_array), m_col(p_col) + { + if (!m_array->CheckColumn(m_col)) { + throw std::out_of_range("ColumnView"); + } + } + T &operator[](int r) { return (*m_array)(r, m_col); } + const T &operator[](int r) const { return (*m_array)(r, m_col); } + int MinIndex() const { return m_array->MinRow(); } + int MaxIndex() const { return m_array->MaxRow(); } + iterator begin() { return iterator(this, MinIndex()); } + iterator begin() const { return iterator(const_cast(this), MinIndex()); } + iterator end() { return iterator(this, MaxIndex() + 1); } + iterator end() const { return iterator(const_cast(this), MaxIndex() + 1); } + }; + + /// @name Lifecycle + //@{ + RectArray() : m_minrow(1), m_maxrow(0), m_mincol(1), m_maxcol(0), m_storage() {} + RectArray(const size_t nrows, const size_t ncols) : RectArray(1, nrows, 1, ncols) {} + RectArray(const int minrow, const int maxrow, const int mincol, const int maxcol) + : m_minrow(minrow), m_maxrow(maxrow), m_mincol(mincol), m_maxcol(maxcol), + m_storage((maxrow >= minrow && maxcol >= mincol) + ? (maxrow - minrow + 1) * (maxcol - mincol + 1) + : 0) + { + } + RectArray(const RectArray &) = default; + RectArray(RectArray &&) noexcept = default; + ~RectArray() = default; + + RectArray &operator=(const RectArray &) = default; + RectArray &operator=(RectArray &&) noexcept = default; + //@} /// @name Range checking functions; returns true only if valid index/size //@{ + /// check array for same row and column boundaries + bool ConformsTo(const RectArray &m) const + { + return (m_minrow == m.m_minrow && m_maxrow == m.m_maxrow && m_mincol == m.m_mincol && + m_maxcol == m.m_maxcol); + } /// check for correct row index - bool CheckRow(int row) const { return (minrow <= row && row <= maxrow); } + bool CheckRow(const int row) const { return (m_minrow <= row && row <= m_maxrow); } /// check row vector for correct column boundaries - template bool CheckRow(const Vector &v) const + template bool ConformsToRow(const V &v) const { - return (v.front_index() == mincol && v.back_index() == maxcol); + return v.front_index() == m_mincol && v.back_index() == m_maxcol; } /// check for correct column index - bool CheckColumn(int col) const { return (mincol <= col && col <= maxcol); } + bool CheckColumn(const int col) const { return (m_mincol <= col && col <= m_maxcol); } /// check column vector for correct row boundaries - template bool CheckColumn(const Vector &v) const + template bool ConformsToColumn(const V &v) const { - return (v.front_index() == minrow && v.back_index() == maxrow); + return (v.front_index() == m_minrow && v.back_index() == m_maxrow); } /// check row and column indices - bool Check(int row, int col) const { return CheckRow(row) && CheckColumn(col); } - /// check matrix for same row and column boundaries - bool CheckBounds(const RectArray &m) const - { - return (minrow == m.minrow && maxrow == m.maxrow && mincol == m.mincol && maxcol == m.maxcol); - } - //@} + bool Check(const int row, const int col) const { return CheckRow(row) && CheckColumn(col); } -public: - /// @name Lifecycle - //@{ - RectArray() : minrow(1), maxrow(0), mincol(1), maxcol(0), data(nullptr) {} - RectArray(unsigned int nrows, unsigned int ncols); - RectArray(int minr, int maxr, int minc, int maxc); - RectArray(const RectArray &); - virtual ~RectArray(); - - RectArray &operator=(const RectArray &); //@} - /// @name General data access //@{ - size_t NumRows() const { return maxrow - minrow + 1; } - size_t NumColumns() const { return maxcol - mincol + 1; } - int MinRow() const { return minrow; } - int MaxRow() const { return maxrow; } - int MinCol() const { return mincol; } - int MaxCol() const { return maxcol; } + size_t NumRows() const { return m_maxrow - m_minrow + 1; } + size_t NumColumns() const { return m_maxcol - m_mincol + 1; } + int MinRow() const { return m_minrow; } + int MaxRow() const { return m_maxrow; } + int MinCol() const { return m_mincol; } + int MaxCol() const { return m_maxcol; } //@} /// @name Indexing operations @@ -87,14 +196,14 @@ template class RectArray { if (!Check(r, c)) { throw std::out_of_range("Index out of range in RectArray"); } - return data[r][c]; + return m_storage[index(r, c)]; } const T &operator()(int r, int c) const { if (!Check(r, c)) { throw std::out_of_range("Index out of range in RectArray"); } - return data[r][c]; + return m_storage[index(r, c)]; } //@} @@ -111,92 +220,40 @@ template class RectArray { if (!Check(i, j)) { throw std::out_of_range("Index out of range in RectArray"); } - std::swap(data[i], data[j]); + if (i == j) { + return; + } + + const auto stride = row_stride(); + const size_t ai = index(i, m_mincol); + const size_t aj = index(j, m_mincol); + for (size_t k = 0; k < stride; ++k) { + std::swap(m_storage[ai + k], m_storage[aj + k]); + } } + RowView GetRowView(int r) { return RowView(*this, r); } + RowView GetRowView(int r) const { return RowView(const_cast(*this), r); } + ColumnView GetColumnView(int c) { return ColumnView(*this, c); } + ColumnView GetColumnView(int c) const { return ColumnView(const_cast(*this), c); } template void GetRow(int, Vector &) const; template void GetColumn(int, Vector &) const; template void SetColumn(int, const Vector &); //@} -}; -//------------------------------------------------------------------------ -// RectArray: Constructors, destructor, constructive operators -//------------------------------------------------------------------------ - -template -RectArray::RectArray(unsigned int rows, unsigned int cols) - : minrow(1), maxrow(rows), mincol(1), maxcol(cols), - data((rows > 0) ? new T *[maxrow] - 1 : nullptr) -{ - for (int i = 1; i <= maxrow; data[i++] = (cols > 0) ? new T[maxcol] - 1 : nullptr) - ; -} - -template -RectArray::RectArray(int minr, int maxr, int minc, int maxc) - : minrow(minr), maxrow(maxr), mincol(minc), maxcol(maxc), - data((maxrow >= minrow) ? new T *[maxrow - minrow + 1] - minrow : nullptr) -{ - for (int i = minrow; i <= maxrow; - data[i++] = (maxcol - mincol + 1) ? new T[maxcol - mincol + 1] - mincol : nullptr) - ; -} - -template -RectArray::RectArray(const RectArray &a) - : minrow(a.minrow), maxrow(a.maxrow), mincol(a.mincol), maxcol(a.maxcol), - data((maxrow >= minrow) ? new T *[maxrow - minrow + 1] - minrow : nullptr) -{ - for (int i = minrow; i <= maxrow; i++) { - data[i] = (maxcol >= mincol) ? new T[maxcol - mincol + 1] - mincol : nullptr; - for (int j = mincol; j <= maxcol; j++) { - data[i][j] = a.data[i][j]; - } - } -} - -template RectArray::~RectArray() -{ - for (int i = minrow; i <= maxrow; i++) { - if (data[i]) { - delete[] (data[i] + mincol); - } - } - if (data) { - delete[] (data + minrow); - } -} - -template RectArray &RectArray::operator=(const RectArray &a) -{ - if (this != &a) { - for (int i = minrow; i <= maxrow; i++) { - if (data[i]) { - delete[] (data[i] + mincol); - } - } - if (data) { - delete[] (data + minrow); - } - - minrow = a.minrow; - maxrow = a.maxrow; - mincol = a.mincol; - maxcol = a.maxcol; - - data = (maxrow >= minrow) ? new T *[maxrow - minrow + 1] - minrow : nullptr; - - for (int i = minrow; i <= maxrow; i++) { - data[i] = (maxcol >= mincol) ? new T[maxcol - mincol + 1] - mincol : nullptr; - for (int j = mincol; j <= maxcol; j++) { - data[i][j] = a.data[i][j]; - } - } - } - - return *this; -} + /// @name Iteration + /// @{ + using element_iterator = typename std::vector::iterator; + using const_element_iterator = typename std::vector::const_iterator; + + element_iterator elements_begin() noexcept { return m_storage.begin(); } + element_iterator elements_end() noexcept { return m_storage.end(); } + const_element_iterator elements_begin() const noexcept { return m_storage.begin(); } + const_element_iterator elements_end() const noexcept { return m_storage.end(); } + const_element_iterator elements_cbegin() const noexcept { return m_storage.cbegin(); } + const_element_iterator elements_cend() const noexcept { return m_storage.cend(); } + /// @} +}; //------------------------------------------------------------------------ // RectArray: Row and column rotation @@ -204,28 +261,26 @@ template RectArray &RectArray::operator=(const RectArray &a) template void RectArray::RotateUp(int lo, int hi) { - if (lo < minrow || hi < lo || maxrow < hi) { + if (lo < m_minrow || hi < lo || m_maxrow < hi) { throw std::out_of_range("Index out of range in RectArray"); } - - T *temp = data[lo]; - for (int k = lo; k < hi; k++) { - data[k] = data[k + 1]; - } - data[hi] = temp; + const auto stride = row_stride(); + const size_t first = index(lo, m_mincol); + const size_t last = index(hi + 1, m_mincol); + std::rotate(m_storage.begin() + first, m_storage.begin() + first + stride, + m_storage.begin() + last); } template void RectArray::RotateDown(int lo, int hi) { - if (lo < minrow || hi < lo || maxrow < hi) { + if (lo < m_minrow || hi < lo || m_maxrow < hi) { throw std::out_of_range("Index out of range in RectArray"); } - - T *temp = data[hi]; - for (int k = hi; k > lo; k--) { - data[k] = data[k - 1]; - } - data[lo] = temp; + const auto stride = row_stride(); + const size_t first = index(lo, m_mincol); + const size_t last = index(hi + 1, m_mincol); + std::rotate(m_storage.begin() + first, m_storage.begin() + last - stride, + m_storage.begin() + last); } //------------------------------------------------------------------------- @@ -237,12 +292,12 @@ template template void RectArray::GetRow(int row, Ve if (!CheckRow(row)) { throw std::out_of_range("Index out of range in RectArray"); } - if (!CheckRow(v)) { + if (!ConformsToRow(v)) { throw DimensionException(); } - const T *rowptr = data[row]; - for (int i = mincol; i <= maxcol; i++) { - v[i] = rowptr[i]; + const size_t base = index(row, m_mincol); + for (int c = m_mincol; c <= m_maxcol; ++c) { + v[c] = m_storage[base + (c - m_mincol)]; } } @@ -255,11 +310,11 @@ template template void RectArray::GetColumn(int col, if (!CheckColumn(col)) { throw std::out_of_range("Index out of range in RectArray"); } - if (!CheckColumn(v)) { + if (!ConformsToColumn(v)) { throw DimensionException(); } - for (int i = minrow; i <= maxrow; i++) { - v[i] = data[i][col]; + for (int r = m_minrow; r <= m_maxrow; ++r) { + v[r] = m_storage[index(r, col)]; } } @@ -268,11 +323,11 @@ template template void RectArray::SetColumn(int col, if (!CheckColumn(col)) { throw std::out_of_range("Index out of range in RectArray"); } - if (!CheckColumn(v)) { + if (!ConformsToColumn(v)) { throw DimensionException(); } - for (int i = minrow; i <= maxrow; i++) { - data[i][col] = v[i]; + for (int r = m_minrow; r <= m_maxrow; ++r) { + m_storage[index(r, col)] = v[r]; } } diff --git a/src/solvers/logit/path.cc b/src/solvers/logit/path.cc index 007bb2e70..2541ec51f 100644 --- a/src/solvers/logit/path.cc +++ b/src/solvers/logit/path.cc @@ -70,9 +70,17 @@ void Givens(Matrix &b, Matrix &q, double &c1, double &c2, int l1 c2 = 0.0; } +void SetAsIdentity(Matrix &M) +{ + M = 0.0; + for (int i = M.MinRow(); i <= M.MaxRow(); ++i) { + M(i, i) = 1.0; + } +} + void QRDecomp(Matrix &b, Matrix &q) { - q.MakeIdent(); + SetAsIdentity(q); for (size_t m = 1; m <= b.NumColumns(); m++) { for (size_t k = m + 1; k <= b.NumRows(); k++) { Givens(b, q, b(m, m), b(k, m), m, k, m + 1);