forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_te.cpp
More file actions
41 lines (35 loc) · 1006 Bytes
/
test_te.cpp
File metadata and controls
41 lines (35 loc) · 1006 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <gtest/gtest.h>
#include <test/cpp/jit/test_utils.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <iostream>
namespace torch {
namespace jit {

// Runs RemoveProfileNodesAndSpecializeTypes on a graph whose prim::If has a
// prim::profile node in each branch: block0 profiles the None constant %1,
// block1 profiles the Tensor graph input %a. The test verifies that the graph
// keeps its Constant / If / two-block / return skeleton and that no
// prim::profile node is left behind.
TEST(TETest, RemoveProfiling) {
  auto g = std::make_shared<Graph>();
  const auto graph_string = R"IR(
graph(%a : Tensor,
%b : bool):
%1 : None = prim::Constant()
%2 : Tensor? = prim::If(%b)
block0():
%3 : Tensor? = prim::profile[profiled_type=Tensor, seen_none=0](%1)
-> (%3)
block1():
%4 : Tensor = prim::profile[profiled_type=Tensor, seen_none=0](%a)
-> (%4)
return (%2))IR";
  torch::jit::parseIR(graph_string, g.get());

  // Lint on both sides of the pass: first the parsed input, then the graph
  // the pass leaves behind.
  g->lint();
  RemoveProfileNodesAndSpecializeTypes(g);
  g->lint();

  // Structural skeleton: these patterns must appear in this order.
  testing::FileCheck()
      .check("prim::Constant")
      ->check("prim::If")
      ->check("block")
      ->check("block")
      ->check("return")
      ->run(*g);

  // The skeleton above is also present in the untouched input graph, so on
  // its own it cannot detect a pass that removed nothing. Assert explicitly
  // that no prim::profile node remains anywhere in the output.
  testing::FileCheck().check_not("prim::profile")->run(*g);
}

} // namespace jit
} // namespace torch