|
20 | 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
21 | 21 | # SOFTWARE. |
22 | 22 | # |
23 | | -from typing import NamedTuple, NamedTupleMeta |
24 | | - |
25 | 23 | try: |
26 | 24 | from torch.utils.data._utils.collate import default_collate_fn_map |
27 | 25 | except ImportError: # ad-hoc for pytorch<1.13 |
28 | 26 | default_collate_fn_map = {} |
29 | 27 |
|
30 | 28 |
|
31 | | -# https://stackoverflow.com/a/50369521 |
32 | | -if hasattr(NamedTuple, '__mro_entries__'): |
33 | | - # Python 3.9 fixed and broke multiple inheritance in a different way |
34 | | - # see https://github.com/python/cpython/issues/88089 |
35 | | - from typing import _NamedTuple |
36 | | - |
37 | | - NamedTuple = _NamedTuple |
38 | | - |
39 | | - |
40 | | -class MultipleInheritanceNamedTupleMeta(NamedTupleMeta): |
41 | | - def __new__(mcls, typename, bases, ns): |
42 | | - if NamedTuple in bases: |
43 | | - base = super().__new__(mcls, '_base_' + typename, bases, ns) |
44 | | - bases = (base, *(b for b in bases if not isinstance(b, NamedTuple))) |
45 | | - return super(NamedTupleMeta, mcls).__new__(mcls, typename, bases, ns) |
46 | | - |
47 | | - |
48 | | -class DataTypeMixin(metaclass=MultipleInheritanceNamedTupleMeta): |
49 | | - def to(self, *args, **kwargs): |
50 | | - return type(self)(*(x.to(*args, **kwargs) for x in self)) |
51 | | - |
52 | | - def cpu(self, *args, **kwargs): |
53 | | - return type(self)(*(x.cpu(*args, **kwargs) for x in self)) |
54 | | - |
55 | | - def cuda(self, *args, **kwargs): |
56 | | - return type(self)(*(x.cuda(*args, **kwargs) for x in self)) |
57 | | - |
58 | | - |
59 | | -__all__ = ['DataTypeMixin', 'NamedTuple', 'default_collate_fn_map'] |
| 29 | +__all__ = ['default_collate_fn_map'] |
0 commit comments