diff --git a/include/cpu_util.h b/include/cpu_util.h
index 0fe6f21..9391367 100644
--- a/include/cpu_util.h
+++ b/include/cpu_util.h
@@ -35,6 +35,7 @@ int invert_cpu(double* A, int size, int mode);
 int mat_root_cpu(double* A, int size);
 int mat_root_inv_cpu(double* A, int size);
 int mat_root_inv_stable_cpu(double* A, int size, double inv_cutoff, int prl);
+int LU_inv_stable_cpu(double* A, int size);
 
 void trans_cpu(float* Bt, float* B, int m, int n);
 void trans_cpu(double* Bt, double* B, int m, int n);
diff --git a/src/integrals/cpu_util.cpp b/src/integrals/cpu_util.cpp
index 7afd923..601d3d2 100644
--- a/src/integrals/cpu_util.cpp
+++ b/src/integrals/cpu_util.cpp
@@ -33,6 +33,10 @@ extern void dsyevx_(
     double* work, int * lwork, int* iwork,
     int* IFAIL, int* info
 );
+
+extern void dgetrf_(int* M, int* N, double* A, int* LDA, int* IPIV, int* INFO);
+
+extern void dgetri_(int* N, double* A, int* LDA, int* IPIV, double* WORK, int* LWORK, int* INFO);
 }
 
 void print_square(int N, double* A);
@@ -778,7 +782,50 @@ int mat_root_inv_stable_cpu(double* A, int size, double inv_cutoff, int prl)
   return nlow;
 }
 
+/**
+ * Invert the size x size matrix A in place using LAPACK: dgetrf_ computes
+ * the LU factorization and dgetri_ builds the inverse from it.
+ *
+ * A    : size*size doubles, overwritten with the inverse on success
+ * size : matrix dimension (also passed to LAPACK as the leading dimension)
+ *
+ * Returns 0 on success. If either LAPACK routine reports a nonzero info
+ * code, that code is printed and the process exits with status 1.
+ */
+int LU_inv_stable_cpu(double* A, int size)
+{
+  //Need to check if debug is necessary, but need ZEST incorporation first
+
+  //LAPACK rejects LDA < max(1,N) and LWORK < max(1,N), so an empty matrix
+  // would be reported as an illegal argument; there is nothing to invert
+  if (size == 0)
+    return 0;
+
+  int size2 = size*size;
+  int info = 0;
+  int* piv = new int[size];
+  double* work = new double[size2];
+
+  dgetrf_(&size,&size,A,&size,piv,&info);
+  if (info != 0)
+    printf("LU factorization failed %i\n",info);
 
+  if (info == 0)
+  {
+    dgetri_(&size,A,&size,piv,work,&size2,&info);
+    if (info != 0)
+      printf("LU inversion failed %i\n",info);
+  }
+
+  //free the scratch buffers on every path, failure included
+  delete [] piv;
+  delete [] work;
+
+  if (info != 0)
+    exit(1);
+
+  return info;
+}
 
 void trans_cpu(float* Bt, float* B, int m, int n)
 {
@@ -1139,4 +1186,4 @@ void mat_times_mat_at_cpu(double* C, double* A, double* B, int M, int N, int K)
   dgemm_(&TB,&TA,&N,&M,&K,&ALPHA,B,&LDB,A,&LDA,&BETA,C,&LDC);
 
   return;
-}
\ No newline at end of file
+}