// Placement.h (from a fork of pytorch/pytorch)
#pragma once
/**
* The implementations in this file are coupled with
* torch/distributed/tensor/placement_types.py.
*/
#include <cstdint>
#include <optional>
#include <string>
#include <string_view>
#include <typeinfo>
#include <utility>
namespace torch::distributed {
/**
 * Polymorphic base class for the placement types declared in this file.
 *
 * Each predicate below returns `false` in the base class; derived classes
 * override the predicate(s) that apply to them.
 */
class Placement {
 public:
  Placement() = default;
  virtual ~Placement() = default;

  // Copying and moving are explicitly defaulted; moves are noexcept.
  Placement(const Placement&) = default;
  Placement& operator=(const Placement&) = default;
  Placement(Placement&&) noexcept = default;
  Placement& operator=(Placement&&) noexcept = default;

  /**
   * Shard predicate.
   *
   * @param dim optional dimension to test against; ignored by the base class.
   * @return always `false` in the base class.
   */
  virtual bool is_shard(
      [[maybe_unused]] std::optional<std::int64_t> dim) const {
    return false;
  }

  /**
   * Replicate predicate.
   *
   * @return always `false` in the base class.
   */
  virtual bool is_replicate() const {
    return false;
  }

  /**
   * Partial predicate.
   *
   * @param reduce_op optional reduce-op name to test against (defaults to
   *        `std::nullopt`); ignored by the base class.
   * @return always `false` in the base class.
   */
  virtual bool is_partial(
      [[maybe_unused]] std::optional<std::string_view> reduce_op =
          std::nullopt) const {
    return false;
  }
};
/**
 * Placement identified by a single dimension index, `dim`.
 *
 * `is_shard` relies on `typeid`, so <typeinfo> must be included before this
 * class is defined.
 */
class Shard : public Placement {
 public:
  /// Dimension index this placement refers to.
  std::int64_t dim;

  /// Builds a Shard for the given dimension index.
  explicit Shard(std::int64_t dim) : dim(dim) {}

  /**
   * Shard predicate.
   *
   * @param dim_ optional dimension to match; `std::nullopt` matches any
   *        dimension.
   * @return `true` only when the dynamic type of this object is exactly
   *         `Shard` and `dim_` is empty or equal to `dim`.
   */
  bool is_shard(std::optional<std::int64_t> dim_) const override {
    // Exact-type check: a class derived from Shard that inherits this
    // override reports `false`.
    if (typeid(*this) != typeid(Shard)) {
      return false;
    }
    // An empty request matches every Shard; otherwise the dims must agree.
    return !dim_.has_value() || *dim_ == dim;
  }

  /// Two Shards are equal when their `dim` values are equal.
  bool operator==(const Shard& rhs) const {
    return dim == rhs.dim;
  }

  /// Negation of `operator==`.
  bool operator!=(const Shard& rhs) const {
    return !(*this == rhs);
  }
};
/**
 * Placement described by a dimension index and a split factor.
 *
 * This class overrides none of the Placement predicates, so all of them keep
 * the base-class answer (`false`).
 */
class StridedShard : public Placement {
 public:
  /// Dimension index this placement refers to.
  std::int64_t dim;
  /// Split factor stored alongside `dim`.
  std::int64_t split_factor;

  /// Stores both arguments verbatim.
  explicit StridedShard(std::int64_t dim, std::int64_t split_factor_)
      : dim(dim), split_factor(split_factor_) {}

  /// Equal when both `dim` and `split_factor` match.
  bool operator==(const StridedShard& rhs) const {
    if (dim != rhs.dim) {
      return false;
    }
    return split_factor == rhs.split_factor;
  }

  /// Unequal when either `dim` or `split_factor` differs.
  bool operator!=(const StridedShard& rhs) const {
    return dim != rhs.dim || split_factor != rhs.split_factor;
  }
};
/**
 * Stateless placement whose `is_replicate` predicate is `true`.
 *
 * The class has no data members, so any two instances compare equal.
 */
class Replicate : public Placement {
 public:
  /// Always `true` for this class.
  bool is_replicate() const override {
    return true;
  }

  /// Always `true`: there is no state to compare.
  bool operator==(const Replicate& /*rhs*/) const {
    return true;
  }

  /// Always `false`: the negation of `operator==`.
  bool operator!=(const Replicate& rhs) const {
    return !(*this == rhs);
  }
};
/**
 * Placement that carries the name of a reduce op as a string.
 *
 * The name is "sum" when none is supplied.
 */
class Partial : public Placement {
 public:
  /// Name of the reduce op.
  std::string reduce_op;

  /// Default-constructs with the reduce op "sum".
  Partial() : Partial("sum") {}

  /// Takes the supplied reduce-op name, or "sum" when the optional is empty.
  explicit Partial(std::optional<std::string> reduce_op_)
      : reduce_op(std::move(reduce_op_).value_or("sum")) {}

  /**
   * Partial predicate.
   *
   * @param op optional reduce-op name to match; `std::nullopt` (the default)
   *        matches any reduce op.
   * @return `true` when `op` is empty or equal to `reduce_op`.
   */
  bool is_partial(
      std::optional<std::string_view> op = std::nullopt) const override {
    if (!op.has_value()) {
      return true;
    }
    return *op == reduce_op;
  }

  /// Equal when the reduce-op names match.
  bool operator==(const Partial& rhs) const {
    return reduce_op == rhs.reduce_op;
  }

  /// Unequal when the reduce-op names differ.
  bool operator!=(const Partial& rhs) const {
    return reduce_op != rhs.reduce_op;
  }
};
} // namespace torch::distributed