Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions adolc/src/py_adolc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,29 @@ adouble wrapped_condassign_adouble_if_else(adouble &res, const adouble &cond, co
return res;
}

double wrapped_condeqassign_double_if(double res, const double cond, const double arg1){
// printf("res = %f\ncond = %f\narg1=%f",res,cond,arg1);
condeqassign(res,cond,arg1);
// printf("after assign res= %f\n",res);
return res;
}

double wrapped_condeqassign_double_if_else(double res, const double cond, const double arg1, const double arg2){
// printf("res = %f\ncond = %f\narg1=%f\narg2=%f\n",res,cond,arg1,arg2);
condeqassign(res,cond,arg1,arg2);
// printf("after assign res= %f\n",res);
return res;
}

adouble wrapped_condeqassign_adouble_if(adouble &res, const adouble &cond, const adouble &arg1){
condeqassign(res,cond,arg1);
return res;
}
adouble wrapped_condeqassign_adouble_if_else(adouble &res, const adouble &cond, const adouble &arg1, const adouble &arg2){
condeqassign(res,cond,arg1,arg2);
return res;
}



/* C STYLE CALLS OF FUNCTIONS */
Expand Down
10 changes: 9 additions & 1 deletion adolc/src/py_adolc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ double wrapped_condassign_double_if_else(double res, const double cond, const do
adouble wrapped_condassign_adouble_if(adouble &res, const adouble &cond, const adouble &arg1);
adouble wrapped_condassign_adouble_if_else(adouble &res, const adouble &cond, const adouble &arg1, const adouble &arg2);

double wrapped_condeqassign_double_if(double res, const double cond, const double arg1);
double wrapped_condeqassign_double_if_else(double res, const double cond, const double arg1, const double arg2);


adouble wrapped_condeqassign_adouble_if(adouble &res, const adouble &cond, const adouble &arg1);
adouble wrapped_condeqassign_adouble_if_else(adouble &res, const adouble &cond, const adouble &arg1, const adouble &arg2);

/* THIN WRAPPER FOR OVERLOADED FUNCTIONS */
int trace_on_default_argument(short tape_tag){ return trace_on(tape_tag,0);}
Expand Down Expand Up @@ -278,6 +281,11 @@ BOOST_PYTHON_MODULE(_adolc)
def("condassign", &wrapped_condassign_adouble_if);
def("condassign", &wrapped_condassign_adouble_if_else);

def("condeqassign", &wrapped_condeqassign_double_if);
def("condeqassign", &wrapped_condeqassign_double_if_else);
def("condeqassign", &wrapped_condeqassign_adouble_if);
def("condeqassign", &wrapped_condeqassign_adouble_if_else);

class_<badouble>("badouble", init<const badouble &>())
.def(boost::python::self_ns::str(self))

Expand Down
87 changes: 82 additions & 5 deletions adolc/tests/test_wrapped_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,24 +220,83 @@ def test_double_condassign_if_else(self):
x = condassign(x,cond,y,z)
assert x == 5

def test_double_condeqassign_if(self):
x = 3.
y = 4.
cond = 0.

x = condeqassign(x,cond,y)
print x
assert_almost_equal(x,4.)

x = 3.
y = 4.
cond = -1.
x = condeqassign(x,cond,y)
print x
assert_almost_equal(x,3.)

def test_double_condeqassign_if_else(self):
x = 3.
y = 4.
z = 5.
cond = 0.

x = condeqassign(x,cond,y,z)
assert x == 4.

x = 3.
y = 4.
z = 5.
cond = -1.

x = condeqassign(x,cond,y,z)
assert x == 5

def test_adouble_condassign_if(self):
x = adouble(3.)
trace_on(1)
x = independent(adouble(3.))
y = adouble(4.)
cond = adouble(1.)
cond = adouble(x-1.)
r = dependent(condassign(x,cond,y))
trace_off()
assert_almost_equal(r.val, 4.)
y_arr = function(1,numpy.array([3.0]))
assert_almost_equal(y_arr[0],4.0)
y_arr = function(1,numpy.array([1.0]))
assert_almost_equal(y_arr[0],1.0)
y_arr = function(1,numpy.array([.5]))
assert_almost_equal(y_arr[0],.5)

x = adouble(3.)
y = adouble(4.)
cond = adouble(-3.)
x = condassign(x,cond,y)
print x
assert_almost_equal(x.val, 4.)
assert_almost_equal(x.val, 3.)

def test_adouble_condeqassign_if(self):
trace_on(1)
x = independent(adouble(3.))
y = adouble(4.)
cond = adouble(x-1.)
r = dependent(condeqassign(x,cond,y))
trace_off()
assert_almost_equal(r.val, 4.)
y_arr = function(1,numpy.array([3.0]))
assert_almost_equal(y_arr[0],4.0)
y_arr = function(1,numpy.array([1.0]))
assert_almost_equal(y_arr[0],4.0)
y_arr = function(1,numpy.array([.5]))
assert_almost_equal(y_arr[0],.5)

x = adouble(3.)
y = adouble(4.)
cond = adouble(-3.)
x = condassign(x,cond,y)
x = condeqassign(x,cond,y)
print x
assert_almost_equal(x.val, 3.)


def test_xuchen_condassign(self):
"""
see https://github.com/b45ch1/pyadolc/issues/12
Expand Down Expand Up @@ -288,6 +347,24 @@ def test_adouble_condassign_if_else(self):
print x
assert_almost_equal(x.val, 5.)

def test_adouble_condeqassign_if_else(self):
x = adouble(3.)
y = adouble(4.)
z = adouble(5.)
cond = adouble(0.)

x = condeqassign(x,cond,y,z)
print x
assert_almost_equal(x.val, 4.)

x = adouble(3.)
y = adouble(4.)
z = adouble(5.)
cond = adouble(-3.)

x = condeqassign(x,cond,y,z)
print x
assert_almost_equal(x.val, 5.)

class CorrectnessTests(TestCase):
def test_sin(self):
Expand Down
29 changes: 29 additions & 0 deletions adolc/wrapped_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,35 @@ def c(v):
z = c(z)
return _adolc.condassign(x, cond, y, z)

def condeqassign(x, cond, y, z = None):
"""

x = condeqassign(cond, y, z = None)

equivalent to:
if cond:
x = y

else:
if z != None:
x = z

"""

def c(v):
if isinstance(v, _adolc.adub):
return _adolc.adouble(v)
return v

x = c(x)
cond = c(cond)
y = c(y)

if z == None:
return _adolc.condeqassign(x, cond, y)
else:
z = c(z)
return _adolc.condeqassign(x, cond, y, z)

def tape_to_latex(tape_tag,x,y):
"""
Expand Down
34 changes: 34 additions & 0 deletions tests/misc_tests/adolc_tiny_unit_test/test_adolc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@ int test_condassign_on_adouble(){
return 0;
}

int test_condeqassign_on_adouble(){
adouble a(2.);
adouble cond(0.);
adouble cond2(-3.);
adouble b(4.);
adouble c(5.);

condeqassign(a,cond,b,c);
if(a != b)return -1;

condeqassign(a,cond2,b,c);
if(a != c)return -1;

return 0;
}

int test_condassign_on_double(){
double a(2.);
double cond(3.);
Expand All @@ -39,6 +55,22 @@ int test_condassign_on_double(){
return 0;
}

int test_condeqassign_on_double(){
double a(2.);
double cond(0.);
double cond2(-3.);
double b(4.);
double c(5.);

condeqassign(a,cond,b,c);
if(a != b) return -1;

condeqassign(a,cond2,b,c);
if(a != c) return -1;

return 0;
}

void check(int error, string test_name){
if(error != 0){
cout<<"Test "<<test_name<<" : failed!"<<endl;
Expand All @@ -50,6 +82,8 @@ void check(int error, string test_name){

int main(){
check(test_condassign_on_adouble(),"test_condassign_on_adouble");
check(test_condeqassign_on_adouble(),"test_condeqassign_on_adouble");
check(test_condassign_on_double(),"test_condassign_on_double");
check(test_condeqassign_on_double(),"test_condeqassign_on_double");
return 0;
}