Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions include/expression_dag.hh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "config.hh"
#include "scalar_node.hh"
#include "assert.hh"
#include "array.hh"


namespace bparser {
Expand All @@ -40,6 +41,8 @@ private:
/// Result nodes, given as input.
NodeVec results;

typedef std::pair<std::string, MultiIdx::VecUint> InvDotNameAndIndices;
typedef std::map<ScalarNodePtr, InvDotNameAndIndices> InvDotMap;

/**
* Used in the setup_result_storage to note number of unclosed nodes
Expand Down Expand Up @@ -102,6 +105,7 @@ public:

/**
* Print ScalarExpression graph in the dot format.
* Useful for debugging
*/
void print_in_dot() {
std::map<ScalarNodePtr , uint> i_node;
Expand Down Expand Up @@ -131,8 +135,217 @@ public:
std::cout << "Node: " << node->op_name_ << "_" << node->result_idx_ << " " << node->result_storage << std::endl;
}

/**
* Print ScalarExpression graph in the common dot format.
* Useful for understanding the DAG.
*/
void print_in_dot2() {
print_in_dot2(InvDotMap());
}

/**
* Print ScalarExpression graph in the common dot format.
* Useful for understanding the DAG. Using the parser's map of var. Name -> Array find the inverse ScalarNodePtr -> var. Name
*/
void print_in_dot2(const std::map<std::string, bparser::Array>& symbols) {
print_in_dot2(create_inverse_map(symbols));
}

/**
* Print ScalarExpression graph in the common dot format.
* Useful for understanding the DAG. Using the map of ScalarNodePtr -> variableName
*/
void print_in_dot2(const InvDotMap& names) {

sort_nodes();

std::cout << "\n" << "----- begin cut here -----" << "\n";
std::cout << "digraph Expr {" << "\n";

std::cout << "/* definitions */" << "\n";

std::cout << "edge [dir=back]" << "\n";
for (uint i = 0; i < sorted.size(); ++i) {
_print_dot_node_definition(sorted[i],names);
}
std::cout << "/* end of definitions */" << "\n";

for (uint i = 0; i < sorted.size(); ++i) {
for (uint in = 0; in < sorted[i]->n_inputs_; ++in) {
std::cout << " ";
_print_dot_node_id(sorted[i]);
std::cout << "\n -> ";
_print_dot_node_id(sorted[i]->inputs_[in]);
std::cout << "\n\n";
}
}
std::cout << "}" << "\n";
std::cout << "----- end cut here -----" << "\n";
std::cout.flush();
}

//Create a map of ScalarNodePtr -> (variable name, is_scalar)
InvDotMap create_inverse_map(const std::map<std::string, bparser::Array>& symbols) const {
InvDotMap inv_map;
if (symbols.empty()) return inv_map;
for (const auto& s : symbols)
{
for (MultiIdx idx(s.second.range()); idx.valid();idx.inc_src()) {
inv_map[s.second[idx]] = InvDotNameAndIndices(s.first, idx.indices());
}
/*for (const auto& n : s.second.elements()) {
inv_map[n] = InvDotNameAndIndices(s.first, s.second.shape().empty());
}*/
}
return inv_map;
}


private:
//Print the vertice identifier for dot
void _print_dot_node_id(const ScalarNodePtr& node) const {
std::cout << node->op_name_ << "_" << (uintptr_t)node.get() << "__" << node->result_storage;// << std::endl;
}

//Print how the vertice should look in dot
void _print_dot_node_definition(const ScalarNodePtr& node, const InvDotMap& invmap) const {
_print_dot_node_id(node);
std::cout << ' ';

switch (node->result_storage) {
case ResultStorage::constant: { // Constant
std::cout << "[shape=circle,";

try { //If the constant has a name
std::string name(invmap.at(node).first);
const MultiIdx::VecUint indices(invmap.at(node).second);
bool scalar(indices.empty());
if (scalar) {
std::cout << "label=\"" << name << '=' << *node->values_ << '\"';
}
else {
MultiIdx::VecUint::size_type size(indices.size());
std::cout << "label=\"" << name << "[";
for (MultiIdx::VecUint::size_type i = 0; i < size; i++) {
std::cout << indices.at(i);
if (i != size - 1) {
std::cout << ',';
}
}
std::cout << "]";
std::cout << '=' << *node->values_ << '\"';
}
std::cout << ", group = \"" << name << '\"';
}
catch (const std::out_of_range&) { //No name
std::cout << "label=\"" << "const " << *node->values_ << '"';
}
std::cout << "]" << std::endl;
break;
}

case ResultStorage::constant_bool: { //Constant bool
std::cout << "[shape=circle,";

try { //If the constant has a name
std::string name(invmap.at(node).first);
const MultiIdx::VecUint indices(invmap.at(node).second);
bool scalar(indices.empty());
if (scalar) {
std::cout << "label=\"" << name << '=' << *node->values_ << '\"';
}
else {
MultiIdx::VecUint::size_type size(indices.size());
std::cout << "label=\"" << name << "[";
for (MultiIdx::VecUint::size_type i = 0; i < size; i++) {
std::cout << indices.at(i);
if (i != size - 1) {
std::cout << ',';
}
}
std::cout << "]";
std::cout << '=' << *node->values_ << '\"';
}
std::cout << ", group = \"" << name << '\"';
}
catch (const std::out_of_range&) { //No name
std::cout << "label=\"" << "const " << *node->values_ << '"';
}
std::cout << "]" << std::endl;
break;
}

case ResultStorage::expr_result: { //Result
std::cout << "[shape=box,label=\"" << node->op_name_ << " [" << node->result_idx_ << "]" << "\"]" << std::endl;
break;
}

case ResultStorage::value: { // Value
std::cout << "[shape=circle,";
try {
std::string name(invmap.at(node).first);
const MultiIdx::VecUint indices(invmap.at(node).second);
bool scalar(indices.empty());
if (scalar) {
std::cout << "label=\"" << name << '"';
}
else {
MultiIdx::VecUint::size_type size(indices.size());
std::cout << "label=\"" << name << "[";
for (MultiIdx::VecUint::size_type i = 0; i < size; i++) {
std::cout << indices.at(i);
if (i != size - 1) {
std::cout << ',';
}
}
std::cout << "]\"";
}
std::cout << ",group=\"" << name << '"';
}
catch (const std::out_of_range&) {
std::cout << "label=<<I>var</I>>";
}

std::cout << "]" << std::endl;
break;
}

case ResultStorage::value_copy: { //Value copy
std::cout << "[shape=circle,";
try {
std::string name(invmap.at(node).first);
MultiIdx::VecUint indices(invmap.at(node).second);
bool scalar(indices.empty());
if (scalar) {
std::cout << "label=\"" << name << '"';
}
else {
MultiIdx::VecUint::size_type size(indices.size());
std::cout << "label=\"" << name << "[";
for (MultiIdx::VecUint::size_type i = 0; i < size; i++) {
std::cout << indices.at(i);
if (i != size - 1) {
std::cout << ',';
}
}
std::cout << "]\"";
}
std::cout << ",group=\"" << name << '"';
}
catch (const std::out_of_range&) {
std::cout << "label=<<I>var_cp</I>>";
}
std::cout << "]" << std::endl;
break;
}

default: { //Temporary & other
std::cout << "[label=\"" << node->op_name_ << "\"]" << std::endl;
break;
}
} //switch
}

void _print_i_node(uint i) {
std::cout << sorted[i]->op_name_ << "_" << i << "_"<< sorted[i]->result_idx_;
}
Expand Down
2 changes: 2 additions & 0 deletions include/parser.hh
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ public:
details::ExpressionDAG se(result_array_.elements());

//se.print_in_dot();
//se.print_in_dot2();
//se.print_in_dot2(symbols_);
processor = ProcessorBase::create_processor(se, max_vec_size, simd_size, arena);
}

Expand Down