Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/cpu_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ int invert_cpu(double* A, int size, int mode);
int mat_root_cpu(double* A, int size);
int mat_root_inv_cpu(double* A, int size);
int mat_root_inv_stable_cpu(double* A, int size, double inv_cutoff, int prl);
int LU_inv_stable_cpu(double* A, int size);

void trans_cpu(float* Bt, float* B, int m, int n);
void trans_cpu(double* Bt, double* B, int m, int n);
Expand Down
35 changes: 34 additions & 1 deletion src/integrals/cpu_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ extern void dsyevx_(
double* work, int * lwork,
int* iwork, int* IFAIL,
int* info );

extern void dgetrf_(int* M, int* N, double* A, int* LDA, int* IPIV, int* INFO);

extern void dgetri_(int* N, double* A, int* LDA, int* IPIV, double* WORK, int* LWORK, int* INFO);
}

void print_square(int N, double* A);
Expand Down Expand Up @@ -778,7 +782,36 @@ int mat_root_inv_stable_cpu(double* A, int size, double inv_cutoff, int prl)
return nlow;
}

int LU_inv_stable_cpu(double* A, int size)
{
//Need to check if debug is necessary, but need ZEST incorporation first
int size2 = size*size;
int infolu;
int infoinv;
int* piv = new int[size];
double* work = new double[size2];

dgetrf_(&size,&size,A,&size,piv,&infolu);
if(infolu != 0)
{
printf("LU factorization failed %i\n",infolu);
delete [] piv;
delete [] work;
exit(1);
}

dgetri_(&size,A,&size,piv,work,&size2,&infoinv);
if(infoinv != 0)
{
printf("LU inversion failed %i\n",infoinv);
exit(1);
}

delete [] piv;
delete [] work;

return infoinv;
}

void trans_cpu(float* Bt, float* B, int m, int n)
{
Expand Down Expand Up @@ -1139,4 +1172,4 @@ void mat_times_mat_at_cpu(double* C, double* A, double* B, int M, int N, int K)
dgemm_(&TB,&TA,&N,&M,&K,&ALPHA,B,&LDB,A,&LDA,&BETA,C,&LDC);

return;
}
}