forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathInferSize.h
More file actions
128 lines (116 loc) · 3.84 KB
/
InferSize.h
File metadata and controls
128 lines (116 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#pragma once
#include <ATen/DimVector.h>
#include <c10/core/ScalarType.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/DimVector.h>
#include <c10/util/Exception.h>
#include <optional>
#include <sstream>
#include <vector>
namespace at {
// Infers the size of the (at most one) dimension given as -1 in `shape`, and
// checks that `shape` is compatible with `numel` elements.
//
// Templated to handle the std::vector<int64_t>, DimVector and SymDimVector
// use cases; see the infer_size / infer_size_dv wrappers below.
//
// Parameters:
//   shape - requested sizes; at most one entry may be -1, every other entry
//           must be greater than -1.
//   numel - number of elements the shape has to account for.
//   res   - output; only the slot at the index of the -1 entry (if any) is
//           written, with numel / (product of the other entries). Nothing
//           else in `res` is touched, so callers are expected to pre-fill it
//           with a copy of `shape` (the wrappers in this file do).
//
// Violations are reported through the TORCH_CHECK / TORCH_MAYBE_SYM_CHECK
// macros, which are defined outside this file.
//
// NOTE(review): std::is_same_v is used below but <type_traits> is not among
// this file's direct includes -- presumably it arrives transitively; confirm.
template <typename InputArrayRef, typename NumelType, typename ResultVec>
inline void infer_size_impl(
InputArrayRef shape,
NumelType numel,
ResultVec& res) {
// Running product of every explicitly specified (non -1) entry of `shape`.
NumelType newsize = 1;
// Position of the -1 entry in `shape`, once one has been found.
// N.B. this is an index, not a sym dim!
std::optional<int64_t> infer_dim;
for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
if (TORCH_GUARD_OR_FALSE(sym_eq(shape[dim], -1))) {
// A second -1 entry is rejected: only one size can be derived from numel.
TORCH_CHECK(!infer_dim, "only one dimension can be inferred");
infer_dim = dim;
} else {
// In case of an unbacked shape[dim] we assume it is not -1 and add a
// runtime assertion.
TORCH_MAYBE_SYM_CHECK(
sym_gt(shape[dim], -1),
"invalid shape dimension ",
shape[dim],
" at index ",
dim,
" of shape ",
shape);
newsize *= shape[dim];
}
}
if (infer_dim) {
// newsize is the product of the known sizes; numel has to be divisible by
// it, and newsize has to be positive unless newsize == numel. When both are
// 0, the ambiguity check further down rejects the shape with a different
// error message.
if constexpr (std::is_same_v<NumelType, c10::SymInt>) {
// Presumably holds the concrete value of newsize when it is not symbolic
// (see the note below) -- confirm against SymInt::maybe_as_int.
auto v = newsize.maybe_as_int();
// `and` is the standard alternative spelling of `&&`.
if (v and *v == 0) {
// Avoid division by 0 when sym_eq(numel % newsize, 0) is constructed,
// which may happen when newsize is not a symbol. If it is a symbol, the
// division won't happen anyway during compile.
TORCH_MAYBE_SYM_CHECK(
numel == newsize,
"shape '",
shape,
"' is invalid for input of size ",
numel);
} else {
// Symbolic form of the condition used in the non-SymInt branch below:
// (newsize > 0 && numel % newsize == 0) || numel == newsize.
auto cond = sym_gt(newsize, 0)
.sym_and(sym_eq(numel % newsize, 0))
.sym_or(sym_eq(numel, newsize));
TORCH_MAYBE_SYM_CHECK(
cond, "shape '", shape, "' is invalid for input of size ", numel);
}
} else {
// For built-in integer types `&&` short-circuits, so numel % newsize is
// only evaluated when newsize > 0.
TORCH_CHECK(
(newsize > 0 && (numel % newsize == 0)) || numel == newsize,
"shape '",
shape,
"' is invalid for input of size ",
numel);
}
// We have a degree of freedom here to select the dimension size; follow
// NumPy semantics and just bail. However, a nice error message is needed
// because users often use `view` as a way to flatten & unflatten
// dimensions and will otherwise be confused why
//   empty_tensor.view( 0, 0)
// works yet
//   empty_tensor.view(-1, 0)
// doesn't.
TORCH_MAYBE_SYM_CHECK(
newsize != 0,
"cannot reshape tensor of 0 elements into shape ",
shape,
" because the unspecified dimension size -1 can be any "
"value and is ambiguous");
// The checks above have asserted that newsize is non-zero and that numel is
// a multiple of it, so this quotient is the size of the inferred dimension.
res[*infer_dim] = numel / newsize;
return;
}
// No -1 entry: the given sizes must account for exactly numel elements.
TORCH_MAYBE_SYM_CHECK(
sym_eq(numel, newsize),
"shape '",
shape,
"' is invalid for input of size ",
numel);
}
// Returns a std::vector<int64_t> copy of `shape` in which a -1 entry, if
// present, has been replaced by the size inferred from `numel`.
// Validation of `shape` against `numel` is delegated to infer_size_impl.
inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
  // Start from the requested sizes; infer_size_impl only overwrites the
  // slot of the -1 entry.
  std::vector<int64_t> inferred = shape.vec();
  infer_size_impl(shape, numel, inferred);
  return inferred;
}
// Same as infer_size, but hands the result back as an at::DimVector.
// Validation of `shape` against `numel` is delegated to infer_size_impl.
inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
  // Seed the result with the requested sizes; infer_size_impl only
  // overwrites the slot of the -1 entry.
  at::DimVector inferred(shape);
  infer_size_impl(shape, numel, inferred);
  return inferred;
}
// SymInt overload: takes a c10::SymIntArrayRef shape and a c10::SymInt
// element count, and returns the resolved sizes as an at::SymDimVector.
// Validation of `shape` against `numel` is delegated to infer_size_impl.
inline at::SymDimVector infer_size_dv(
    c10::SymIntArrayRef shape,
    c10::SymInt numel) {
  // Seed the result with the requested sizes; infer_size_impl only
  // overwrites the slot of the -1 entry.
  at::SymDimVector inferred(shape);
  // Template arguments are deduced as <c10::SymIntArrayRef, c10::SymInt,
  // at::SymDimVector>; numel is moved because it is not used afterwards.
  infer_size_impl(shape, std::move(numel), inferred);
  return inferred;
}
} // namespace at