From 66554451186206c19907ea42e9a6f0bf5e6a7933 Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Fri, 11 Dec 2020 12:15:52 -0500 Subject: [PATCH 1/9] Added a period parameter to the transformer (#1) Co-authored-by: Max Cohen --- tst/transformer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tst/transformer.py b/tst/transformer.py index e57f634..1062b13 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -43,11 +43,13 @@ class Transformer(nn.Module): Dropout probability after each MHA or PFF block. Default is ``0.3``. chunk_mode: - Swict between different MultiHeadAttention blocks. + Switch between different MultiHeadAttention blocks. One of ``'chunk'``, ``'window'`` or ``None``. Default is ``'chunk'``. pe: Type of positional encoding to add. Must be one of ``'original'``, ``'regular'`` or ``None``. Default is ``None``. + pe_period: + If using the ``'regular'` pe, then we can define the period. Default is ``24``. """ def __init__(self, @@ -61,7 +63,8 @@ def __init__(self, attention_size: int = None, dropout: float = 0.3, chunk_mode: str = 'chunk', - pe: str = None): + pe: str = None, + pe_period: int = 24): """Create transformer structure from Encoder and Decoder blocks.""" super().__init__() @@ -92,6 +95,7 @@ def __init__(self, if pe in pe_functions.keys(): self._generate_PE = pe_functions[pe] + self._pe_period = pe_period elif pe is None: self._generate_PE = None else: @@ -122,7 +126,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Add position encoding if self._generate_PE is not None: - positional_encoding = self._generate_PE(K, self._d_model) + pe_params = {'period': self._pe_period} if self._pe_period else {} + positional_encoding = self._generate_PE(K, self._d_model, **pe_params) positional_encoding = positional_encoding.to(encoding.device) encoding.add_(positional_encoding) From b74fa8d87780d1f9c2d05a03b106f3b4d7b01f1d Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Fri, 11 Dec 2020 12:17:02 -0500 Subject: [PATCH 2/9] Fix --- tst/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tst/transformer.py b/tst/transformer.py index 1062b13..8f24e91 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -64,7 +64,7 @@ def __init__(self, dropout: float = 0.3, chunk_mode: str = 'chunk', pe: str = None, - pe_period: int = 24): + pe_period: int = None): """Create transformer structure from Encoder and Decoder blocks.""" super().__init__() From c62339f11036eefa35c5c9de3272bde9be9362d4 Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Fri, 11 Dec 2020 12:18:58 -0500 Subject: [PATCH 3/9] Pe/fix (#2) * Added a period parameter to the transformer * Fix Co-authored-by: Max Cohen --- tst/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tst/transformer.py b/tst/transformer.py index 1062b13..8f24e91 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -64,7 +64,7 @@ def __init__(self, dropout: float = 0.3, chunk_mode: str = 'chunk', pe: str = None, - pe_period: int = 24): + pe_period: int = None): """Create transformer structure from Encoder and Decoder blocks.""" super().__init__() From 820d4a96ad768cb891ed39772e9fcc8250fb4097 Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Fri, 18 Dec 2020 13:12:38 -0500 Subject: [PATCH 4/9] Fixed --- tst/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tst/transformer.py b/tst/transformer.py index 8f24e91..7eed5c6 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -47,9 +47,9 @@ class Transformer(nn.Module): One of ``'chunk'``, ``'window'`` or ``None``. Default is ``'chunk'``. pe: Type of positional encoding to add. - Must be one of ``'original'``, ``'regular'`` or ``None``. Default is ``None``. + Must be one of `original, `regular` or `None. Default is `None`. pe_period: - If using the ``'regular'` pe, then we can define the period. Default is ``24``. + If using the ``regular`` pe, then we can define the period. Default is ``24``. """ def __init__(self, From 7dd79b50d74c85cbb43a04ade7760f046ed9c323 Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Fri, 18 Dec 2020 13:18:10 -0500 Subject: [PATCH 5/9] Pe/fix (#3) * Added a period parameter to the transformer * Fix * Fixed Co-authored-by: Max Cohen --- tst/transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tst/transformer.py b/tst/transformer.py index 8f24e91..7eed5c6 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -47,9 +47,9 @@ class Transformer(nn.Module): One of ``'chunk'``, ``'window'`` or ``None``. Default is ``'chunk'``. pe: Type of positional encoding to add. - Must be one of ``'original'``, ``'regular'`` or ``None``. Default is ``None``. + Must be one of `original, `regular` or `None. Default is `None`. pe_period: - If using the ``'regular'` pe, then we can define the period. Default is ``24``. + If using the ``regular`` pe, then we can define the period. Default is ``24``. """ def __init__(self, From 8df46b88ed03509f6ffe5c449684d651093b7940 Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Fri, 18 Dec 2020 13:47:06 -0500 Subject: [PATCH 6/9] Edge check --- tst/transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tst/transformer.py b/tst/transformer.py index 7eed5c6..9ab9d33 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -95,7 +95,9 @@ def __init__(self, if pe in pe_functions.keys(): self._generate_PE = pe_functions[pe] - self._pe_period = pe_period + + if pe == 'regular' and self._pe_period is not None: + self._pe_period = pe_period elif pe is None: self._generate_PE = None else: From 4ea6b11642907c9bea96df3ec40cb4a49366f232 Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Fri, 18 Dec 2020 13:48:15 -0500 Subject: [PATCH 7/9] Pe/fix (#4) * Added a period parameter to the transformer * Fix * Fixed * Edge check Co-authored-by: Max Cohen --- tst/transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tst/transformer.py b/tst/transformer.py index 7eed5c6..9ab9d33 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -95,7 +95,9 @@ def __init__(self, if pe in pe_functions.keys(): self._generate_PE = pe_functions[pe] - self._pe_period = pe_period + + if pe == 'regular' and self._pe_period is not None: + self._pe_period = pe_period elif pe is None: self._generate_PE = None else: From f89631864d9db3b4e08fa7f233eab14b8208ad99 Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Sun, 20 Dec 2020 11:28:08 -0500 Subject: [PATCH 8/9] Pe/fix (#5) * Added a period parameter to the transformer * Fix * Fixed * Edge check Co-authored-by: Max Cohen From c54af33450202b5cabf1d41e5ada4a2dac3bd5d0 Mon Sep 17 00:00:00 2001 From: Shamoon Siddiqui Date: Sun, 20 Dec 2020 11:29:31 -0500 Subject: [PATCH 9/9] Fix --- tst/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tst/transformer.py b/tst/transformer.py index 9ab9d33..4d5a070 100644 --- a/tst/transformer.py +++ b/tst/transformer.py @@ -96,7 +96,7 @@ def __init__(self, if pe in pe_functions.keys(): self._generate_PE = pe_functions[pe] - if pe == 'regular' and self._pe_period is not None: + if pe == 'regular' and pe_period is not None: self._pe_period = pe_period elif pe is None: self._generate_PE = None