diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index e78017053..ef12720a2 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -2245,14 +2245,6 @@ "lineCount": 1 } }, - { - "code": "reportDeprecated", - "range": { - "startColumn": 17, - "endColumn": 22, - "lineCount": 1 - } - }, { "code": "reportUnreachable", "range": { @@ -3188,22 +3180,6 @@ "endColumn": 46, "lineCount": 1 } - }, - { - "code": "reportAny", - "range": { - "startColumn": 8, - "endColumn": 13, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 29, - "endColumn": 34, - "lineCount": 1 - } } ], "./pytato/cmath.py": [ @@ -7033,14 +7009,6 @@ "lineCount": 1 } }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 20, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -11397,14 +11365,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 19, - "endColumn": 30, - "lineCount": 1 - } - }, { "code": "reportUnknownLambdaType", "range": { @@ -11487,22 +11447,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 29, - "endColumn": 40, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -11527,14 +11471,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 32, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -11546,23 +11482,23 @@ { "code": "reportUnknownParameterType", "range": { - "startColumn": 28, - "endColumn": 39, + "startColumn": 18, + "endColumn": 22, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 20, - "endColumn": 31, + "startColumn": 15, + "endColumn": 23, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 22, + "startColumn": 25, "endColumn": 33, "lineCount": 1 } @@ -11571,175 +11507,175 @@ "code": "reportUnknownParameterType", "range": { "startColumn": 35, - "endColumn": 42, + "endColumn": 40, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 44, - "endColumn": 51, + "startColumn": 4, + "endColumn": 16, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 18, - "endColumn": 22, + "startColumn": 17, + "endColumn": 18, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 15, - "endColumn": 23, + "startColumn": 8, + "endColumn": 15, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 25, - "endColumn": 33, + "startColumn": 17, + "endColumn": 21, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 35, - "endColumn": 40, + "startColumn": 32, + "endColumn": 45, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 23, - "endColumn": 34, + "startColumn": 59, + "endColumn": 72, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 36, - "endColumn": 44, + "startColumn": 32, + "endColumn": 45, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 46, - "endColumn": 54, + "startColumn": 59, + "endColumn": 72, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 26, - "endColumn": 37, + "startColumn": 33, + "endColumn": 42, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 39, - "endColumn": 47, + "startColumn": 4, + "endColumn": 24, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 49, - "endColumn": 57, + "startColumn": 25, + "endColumn": 30, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 22, - "endColumn": 33, + "startColumn": 15, + "endColumn": 35, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 40, - "endColumn": 51, + "startColumn": 10, + "endColumn": 38, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 14, - "endColumn": 25, + "startColumn": 19, + "endColumn": 33, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 27, - "endColumn": 32, + "startColumn": 19, + "endColumn": 66, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnusedVariable", "range": { - "startColumn": 34, - "endColumn": 38, + "startColumn": 4, + "endColumn": 5, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 26, + "startColumn": 10, "endColumn": 37, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 39, - "endColumn": 43, + "startColumn": 18, + "endColumn": 30, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 19, - "endColumn": 30, + "startColumn": 22, + "endColumn": 34, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 4, - "endColumn": 16, + "startColumn": 14, + "endColumn": 24, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 17, - "endColumn": 18, + "startColumn": 14, + "endColumn": 33, "lineCount": 1 } }, @@ -11747,844 +11683,116 @@ "code": "reportUnknownParameterType", "range": { "startColumn": 8, - "endColumn": 15, + "endColumn": 27, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 17, - "endColumn": 21, + "startColumn": 28, + "endColumn": 33, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 35, - "endColumn": 46, + "startColumn": 8, + "endColumn": 20, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 48, - "endColumn": 53, + "startColumn": 21, + "endColumn": 23, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 55, - "endColumn": 62, + "startColumn": 25, + "endColumn": 27, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownParameterType", "range": { - "startColumn": 32, - "endColumn": 45, + "startColumn": 8, + "endColumn": 20, "lineCount": 1 } }, { - "code": "reportUnknownMemberType", + "code": "reportUnknownParameterType", "range": { - "startColumn": 59, - "endColumn": 72, + "startColumn": 21, + "endColumn": 23, "lineCount": 1 } }, { "code": "reportUnknownParameterType", "range": { - "startColumn": 34, - "endColumn": 45, + "startColumn": 25, + "endColumn": 27, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportArgumentType", "range": { - "startColumn": 47, - "endColumn": 52, + "startColumn": 37, + "endColumn": 44, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 54, - "endColumn": 61, + "startColumn": 15, + "endColumn": 70, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 32, - "endColumn": 45, + "startColumn": 15, + "endColumn": 43, "lineCount": 1 } }, { "code": "reportUnknownMemberType", "range": { - "startColumn": 59, - "endColumn": 72, + "startColumn": 15, + "endColumn": 70, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportUnknownMemberType", "range": { - "startColumn": 22, - "endColumn": 33, + "startColumn": 15, + "endColumn": 43, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", + "code": "reportArgumentType", "range": { - "startColumn": 35, - "endColumn": 40, + "startColumn": 40, + "endColumn": 52, "lineCount": 1 } }, { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 34, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 4, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 15, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 15, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 17, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 30, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 40, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 30, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportUnusedVariable", - "range": { - "startColumn": 4, - "endColumn": 5, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 18, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 22, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 24, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 57, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 31, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 51, - "endColumn": 64, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 38, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 40, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 35, - "endColumn": 41, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 43, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 14, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 38, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 10, - "endColumn": 26, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 43, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 12, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 13, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 42, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 56, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 20, - "endColumn": 31, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 33, - "endColumn": 37, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 39, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 16, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 29, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 35, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 42, - "endColumn": 53, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 8, - "endColumn": 20, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 37, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 66, - "endColumn": 77, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 48, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 13, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 29, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 16, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 48, - "endColumn": 59, - "lineCount": 1 - } - }, - { - "code": "reportUnusedParameter", - "range": { - "startColumn": 48, - "endColumn": 59, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 61, - "endColumn": 66, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 68, - "endColumn": 76, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 13, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 35, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 53, - "endColumn": 60, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 70, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 70, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 40, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 14, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 36, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 54, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", + "code": "reportUnknownMemberType", "range": { "startColumn": 15, "endColumn": 70, @@ -12639,38 +11847,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 29, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 22, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 56, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 43, - "lineCount": 1 - } - }, { "code": "reportUnusedVariable", "range": { @@ -12715,23 +11891,7 @@ "code": "reportUnknownMemberType", "range": { "startColumn": 16, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 28, - "endColumn": 39, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 35, - "endColumn": 46, + "endColumn": 33, "lineCount": 1 } }, @@ -12791,14 +11951,6 @@ "lineCount": 4 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 38, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -12823,38 +11975,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 17, - "endColumn": 28, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 30, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 34, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 46, - "endColumn": 57, - "lineCount": 1 - } - }, { "code": "reportImplicitOverride", "range": { @@ -12871,22 +11991,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 27, - "endColumn": 38, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 45, - "endColumn": 56, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -12951,14 +12055,6 @@ "lineCount": 2 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 34, - "endColumn": 45, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -12999,62 +12095,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 41, - "endColumn": 52, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 21, - "endColumn": 32, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 19, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 46, - "endColumn": 57, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 64, - "endColumn": 75, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -13111,14 +12151,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 13, - "endColumn": 24, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -13143,14 +12175,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - "endColumn": 34, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -13279,14 +12303,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 31, - "endColumn": 42, - "lineCount": 1 - } - }, { "code": "reportUnknownLambdaType", "range": { @@ -13423,14 +12439,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 36, - "lineCount": 1 - } - }, { "code": "reportArgumentType", "range": { @@ -13550,14 +12558,6 @@ "endColumn": 61, "lineCount": 1 } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 26, - "endColumn": 37, - "lineCount": 1 - } } ], "./test/test_distributed.py": [ @@ -13601,14 +12601,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 41, - "endColumn": 52, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -13681,14 +12673,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 46, - "endColumn": 57, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -13697,30 +12681,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 39, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 50, - "endColumn": 61, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 40, - "endColumn": 51, - "lineCount": 1 - } - }, { "code": "reportPossiblyUnboundVariable", "range": { @@ -13737,14 +12697,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 68, - "endColumn": 79, - "lineCount": 1 - } - }, { "code": "reportPossiblyUnboundVariable", "range": { @@ -13761,14 +12713,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 60, - "endColumn": 71, - "lineCount": 1 - } - }, { "code": "reportPossiblyUnboundVariable", "range": { @@ -13785,14 +12729,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 48, - "endColumn": 59, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -13801,14 +12737,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 35, - "endColumn": 46, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -13833,22 +12761,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 25, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 37, - "endColumn": 48, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { @@ -13857,14 +12769,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 43, - "endColumn": 54, - "lineCount": 1 - } - }, { "code": "reportUnusedParameter", "range": { diff --git a/doc/misc.rst b/doc/misc.rst index e6be12f33..4bd35a7fd 100644 --- a/doc/misc.rst +++ b/doc/misc.rst @@ -46,6 +46,12 @@ Cross-References for Other Documentation The type of the Python-builtin :data:`Ellipsis` object. (not otherwise documented) +.. currentmodule:: prim + +.. class:: NaN + + See :class:`pymbolic.primitives.NaN`. + .. currentmodule:: loopy.kernel .. class:: LoopKernel diff --git a/pytato/array.py b/pytato/array.py index 33a62dd63..bf2068dca 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -48,9 +48,15 @@ .. autoclass:: DictOfNamedArrays .. autoclass:: AbstractResultWithNamedArrays +.. currentmodule:: pytato.array + +.. autoclass:: ArrayOrScalarT + NumPy-Like Interface -------------------- +.. currentmodule:: pytato + These functions generally follow the interface of the corresponding functions in :mod:`numpy`, but not all NumPy features may be supported. @@ -216,25 +222,34 @@ INT_CLASSES, SCALAR_CLASSES, ScalarExpression, + TypeCast, get_reduction_induction_variables, ) if TYPE_CHECKING: + from numpy.typing import DTypeLike _dtype_any = np.dtype[Any] else: _dtype_any = np.dtype +# {{{ typing helpers + AxesT = tuple["Axis", ...] ArrayT = TypeVar("ArrayT", bound="Array") +ArrayOrScalar: TypeAlias = "Array | Scalar" -# {{{ shape +ArrayOrScalarT = TypeVar("ArrayOrScalarT", "Array", Scalar, ArrayOrScalar) -ShapeComponent = Union[Integer, "Array"] -ShapeType = tuple[ShapeComponent, ...] -ConvertibleToShape = ShapeComponent | Sequence[ShapeComponent] +ShapeComponent: TypeAlias = "Integer | Array" +ShapeType: TypeAlias = tuple[ShapeComponent, ...] +ConvertibleToShape: TypeAlias = "ShapeComponent | Sequence[ShapeComponent]" +# }}} + + +# {{{ shape def _check_identifier(s: str | None, optional: bool) -> bool: if s is None: @@ -898,7 +913,7 @@ def __xor__(self, other: ArrayOrScalar) -> Array: def __rxor__(self, other: ArrayOrScalar) -> Array: return self._binary_op(operator.xor, other, reverse=True) - def conj(self) -> ArrayOrScalar: + def conj(self) -> Array: import pytato as pt return pt.conj(self) @@ -913,15 +928,32 @@ def __bool__(self) -> None: raise ValueError("The truth value of an array expression is undefined.") @property - def real(self) -> ArrayOrScalar: + def real(self) -> Array: import pytato as pt return pt.real(self) @property - def imag(self) -> ArrayOrScalar: + def imag(self) -> Array: import pytato as pt return pt.imag(self) + def astype(self, dtype: DTypeLike) -> Array: + dtype = np.dtype(dtype) + if self.dtype.kind in ["f", "c"] and dtype.kind in ["i", "u"]: + raise NotImplementedError("numpy-like overflow behavior in float-to-int") + if self.dtype.kind == "c" and dtype.kind in ["i", "u", "f"]: + raise NotImplementedError("complex-to-real casts fail in loopy") + + from pymbolic import var + return make_index_lambda( + TypeCast(dtype, var("in_0")[ + tuple(var(f"_{i}") for i in range(self.ndim)) + ]), + bindings={"in_0": self}, + shape=self.shape, + dtype=dtype, + ) + def reshape(self, *shape: int | Sequence[int], order: str = "C") -> Array: import pytato as pt if len(shape) == 0: @@ -934,14 +966,18 @@ def reshape(self, *shape: int | Sequence[int], order: str = "C") -> Array: # expected "Union[int, Sequence[int]]" return pt.reshape(self, shape, order=order) # type: ignore[arg-type] - def all(self, axis: int = 0) -> ArrayOrScalar: + def all(self, + axis: int | tuple[int, ...] | None = None, + ) -> ArrayOrScalar: """ Equivalent to :func:`pytato.all`. """ import pytato as pt return pt.all(self, axis) - def any(self, axis: int = 0) -> ArrayOrScalar: + def any(self, + axis: int | tuple[int, ...] | None = None, + ) -> ArrayOrScalar: """ Equivalent to :func:`pytato.any`. """ @@ -965,9 +1001,6 @@ def __repr__(self) -> str: from pytato.stringifier import Reprifier return Reprifier()(self) - -ArrayOrScalar: TypeAlias = Array | Scalar - # }}} @@ -2760,7 +2793,17 @@ def greater_equal(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: # {{{ logical operations -def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: +@overload +def logical_or(x1: Scalar, x2: Scalar, /) -> bool: ... + +@overload +def logical_or(x1: ArrayOrScalar, x2: Array, /) -> Array: ... + +@overload +def logical_or(x1: Array, x2: ArrayOrScalar, /) -> Array: ... + + +def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar, /) -> Array | bool: """ Returns the element-wise logical OR of *x1* and *x2*. """ @@ -2778,6 +2821,16 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: ) # type: ignore[return-value] +@overload +def logical_and(x1: Scalar, x2: Scalar, /) -> bool: ... + +@overload +def logical_and(x1: ArrayOrScalar, x2: Array, /) -> Array: ... + +@overload +def logical_and(x1: Array, x2: ArrayOrScalar, /) -> Array: ... + + def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: """ Returns the element-wise logical AND of *x1* and *x2*. @@ -2890,8 +2943,7 @@ def maximum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: or np.issubdtype(common_dtype, np.complexfloating)): from pytato.cmath import isnan return where(logical_or(isnan(x1), isnan(x2)), - # I don't know why pylint thinks common_dtype is a tuple. - common_dtype.type(np.nan), # pylint: disable=no-member + common_dtype.type(np.nan), where(greater(x1, x2), x1, x2)) else: return where(greater(x1, x2), x1, x2) @@ -2909,8 +2961,7 @@ def minimum(x1: ArrayOrScalar, x2: ArrayOrScalar) -> ArrayOrScalar: or np.issubdtype(common_dtype, np.complexfloating)): from pytato.cmath import isnan return where(logical_or(isnan(x1), isnan(x2)), - # I don't know why pylint thinks common_dtype is a tuple. - common_dtype.type(np.nan), # pylint: disable=no-member + common_dtype.type(np.nan), where(less(x1, x2), x1, x2)) else: return where(less(x1, x2), x1, x2) @@ -2924,7 +2975,7 @@ def make_index_lambda( expression: str | ScalarExpression, bindings: Mapping[str, Array], shape: ShapeType, - dtype: Any, + dtype: DTypeLike, var_to_reduction_descr: Mapping[str, ReductionDescriptor] | None = None ) -> IndexLambda: if isinstance(expression, str): @@ -2966,7 +3017,7 @@ def make_index_lambda( return IndexLambda(expr=expression, bindings=immutabledict(bindings), shape=shape, - dtype=dtype, + dtype=np.dtype(dtype), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), axes=_get_default_axes(len(shape)), diff --git a/pytato/cmath.py b/pytato/cmath.py index f68e4897b..9c787edef 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -63,11 +63,12 @@ from immutabledict import immutabledict import pymbolic.primitives as prim -from pymbolic import Scalar, var +from pymbolic import var from pytato.array import ( Array, ArrayOrScalar, + ArrayOrScalarT, IndexLambda, _dtype_any, _get_created_at_tag, @@ -81,17 +82,17 @@ from pymbolic.typing import Expression -def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...], +def _apply_elem_wise_func(inputs: tuple[ArrayOrScalarT, ...], func_name: str, ret_dtype: _dtype_any | None = None, np_func_name: str | None = None - ) -> ArrayOrScalar: + ) -> ArrayOrScalarT: if all(isinstance(x, SCALAR_CLASSES) for x in inputs): if np_func_name is None: np_func_name = func_name np_func = getattr(np, np_func_name) - return cast("ArrayOrScalar", np_func(*inputs)) + return cast("ArrayOrScalarT", np_func(*inputs)) if not inputs: raise ValueError("at least one argument must be present") @@ -126,7 +127,7 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...], assert shape is not None assert ret_dtype is not None - return IndexLambda( + return cast("ArrayOrScalarT", IndexLambda( expr=prim.Call(var(f"pytato.c99.{func_name}"), tuple(sym_args)), shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings), @@ -134,7 +135,7 @@ def _apply_elem_wise_func(inputs: tuple[ArrayOrScalar, ...], non_equality_tags=_get_created_at_tag(stacklevel=2), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict(), - ) + )) def _get_dtype(x: ArrayOrScalar) -> _dtype_any: @@ -147,7 +148,7 @@ def _get_dtype(x: ArrayOrScalar) -> _dtype_any: # FIXME: Overload these instead of returning union type? -def abs(x: ArrayOrScalar) -> ArrayOrScalar: +def abs(x: ArrayOrScalarT) -> ArrayOrScalarT: x_dtype = _get_dtype(x) if x_dtype.kind == "c": result_dtype = np.empty(0, dtype=x_dtype).real.dtype @@ -157,73 +158,73 @@ def abs(x: ArrayOrScalar) -> ArrayOrScalar: return _apply_elem_wise_func((x,), "abs", ret_dtype=result_dtype) -def sqrt(x: ArrayOrScalar) -> ArrayOrScalar: +def sqrt(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "sqrt") -def sin(x: ArrayOrScalar) -> ArrayOrScalar: +def sin(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "sin") -def cos(x: ArrayOrScalar) -> ArrayOrScalar: +def cos(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "cos") -def tan(x: ArrayOrScalar) -> ArrayOrScalar: +def tan(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "tan") -def arcsin(x: ArrayOrScalar) -> ArrayOrScalar: +def arcsin(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "asin", np_func_name="arcsin") -def arccos(x: ArrayOrScalar) -> ArrayOrScalar: +def arccos(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "acos", np_func_name="arccos") -def arctan(x: ArrayOrScalar) -> ArrayOrScalar: +def arctan(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "atan", np_func_name="arctan") -def conj(x: ArrayOrScalar) -> ArrayOrScalar: +def conj(x: ArrayOrScalarT) -> ArrayOrScalarT: if _get_dtype(x).kind != "c": return x return _apply_elem_wise_func((x,), "conj") -def arctan2(y: ArrayOrScalar, x: ArrayOrScalar) -> ArrayOrScalar: +def arctan2(y: ArrayOrScalarT, x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((y, x), "atan2", np_func_name="arctan2") -def sinh(x: ArrayOrScalar) -> ArrayOrScalar: +def sinh(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "sinh") -def cosh(x: ArrayOrScalar) -> ArrayOrScalar: +def cosh(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "cosh") -def tanh(x: ArrayOrScalar) -> ArrayOrScalar: +def tanh(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "tanh") -def exp(x: ArrayOrScalar) -> ArrayOrScalar: +def exp(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "exp") -def log(x: ArrayOrScalar) -> ArrayOrScalar: +def log(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "log") -def log10(x: ArrayOrScalar) -> ArrayOrScalar: +def log10(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "log10") -def isnan(x: ArrayOrScalar) -> ArrayOrScalar: +def isnan(x: ArrayOrScalarT) -> ArrayOrScalarT: return _apply_elem_wise_func((x,), "isnan", np.dtype(np.int32)) -def real(x: ArrayOrScalar) -> ArrayOrScalar: +def real(x: ArrayOrScalarT) -> ArrayOrScalarT: x_dtype = _get_dtype(x) if x_dtype.kind == "c": result_dtype = np.empty(0, dtype=x_dtype).real.dtype @@ -232,17 +233,17 @@ def real(x: ArrayOrScalar) -> ArrayOrScalar: return _apply_elem_wise_func((x,), "real", ret_dtype=result_dtype) -def imag(x: ArrayOrScalar) -> ArrayOrScalar: +def imag(x: ArrayOrScalarT) -> ArrayOrScalarT: x_dtype = _get_dtype(x) if x_dtype.kind == "c": result_dtype = np.empty(0, dtype=x_dtype).real.dtype else: if np.isscalar(x): - return cast("Scalar", x_dtype.type(0)) + return cast("ArrayOrScalarT", x_dtype.type(0)) else: assert isinstance(x, Array) import pytato as pt - return pt.zeros(x.shape, dtype=x_dtype) + return cast("ArrayOrScalarT", pt.zeros(x.shape, dtype=x_dtype)) return _apply_elem_wise_func((x,), "imag", ret_dtype=result_dtype) # vim: fdm=marker diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 6d6fb2301..ea25defaa 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -52,7 +52,7 @@ import numpy as np from immutabledict import immutabledict -from typing_extensions import Never, TypeIs +from typing_extensions import Never, TypeIs, override import pymbolic.primitives as prim from pymbolic import ArithmeticExpression, Bool, Expression, expr_dataclass @@ -180,6 +180,7 @@ def __init__(self, *, composite_leaves=composite_leaves) self.include_idx_lambda_indices = include_idx_lambda_indices + @override def map_variable(self, expr: prim.Variable, *args: P.args, **kwargs: P.kwargs ) -> DependenciesT: @@ -196,9 +197,12 @@ def map_reduce(self, expr: Reduce, set().union(*(self.rec((lb, ub), *args, **kwargs) for (lb, ub) in expr.bounds.values()))]) + def map_type_cast(self, expr: TypeCast, + *args: P.args, **kwargs: P.kwargs) -> DependenciesT: + return self.rec(expr.inner_expr, *args, **kwargs) -class EvaluationMapper(EvaluationMapperBase[ResultT]): +class EvaluationMapper(EvaluationMapperBase[ResultT]): def map_reduce(self, expr: Reduce) -> Never: # TODO: not trivial to evaluate symbolic reduction nodes raise NotImplementedError() diff --git a/test/test_apps.py b/test/test_apps.py index bdb3afc14..410c85510 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -126,7 +126,7 @@ def make_fft_realization_mapper(fft_vec_gatherer): return FFTRealizationMapper(old_array_to_new_array) -def test_trace_fft(ctx_factory): +def test_trace_fft(ctx_factory: cl.CtxFactory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) diff --git a/test/test_codegen.py b/test/test_codegen.py index 04005f68e..019d4b1e2 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -50,7 +50,7 @@ import pytato as pt -def test_basic_codegen(ctx_factory): +def test_basic_codegen(ctx_factory: cl.CtxFactory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -61,7 +61,7 @@ def test_basic_codegen(ctx_factory): assert (out == x_in * x_in).all() -def test_ctx_bound_execution(ctx_factory): +def test_ctx_bound_execution(ctx_factory: cl.CtxFactory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -74,7 +74,7 @@ def test_ctx_bound_execution(ctx_factory): assert (out.get() == x_in * x_in).all() -def test_named_clash(ctx_factory): +def test_named_clash(ctx_factory: cl.CtxFactory): x = pt.make_placeholder("x", (5,), np.int64) from pytato.tags import ImplStored, Named @@ -86,7 +86,7 @@ def test_named_clash(ctx_factory): pt.generate_loopy(expr) -def test_scalar_placeholder(ctx_factory): +def test_scalar_placeholder(ctx_factory: cl.CtxFactory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -99,7 +99,7 @@ def test_scalar_placeholder(ctx_factory): # https://github.com/inducer/pytato/issues/15 @pytest.mark.xfail # shape inference solver: not yet implemented -def test_size_param(ctx_factory): +def test_size_param(ctx_factory: cl.CtxFactory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -113,7 +113,7 @@ def test_size_param(ctx_factory): @pytest.mark.parametrize("x1_ndim", (1, 2)) @pytest.mark.parametrize("x2_ndim", (1, 2)) -def test_matmul_basic(ctx_factory, x1_ndim, x2_ndim): +def test_matmul_basic(ctx_factory: cl.CtxFactory, x1_ndim, x2_ndim): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -153,7 +153,7 @@ def _do_matmul(x1_shape, x2_shape, queue): @pytest.mark.parametrize("x1_shape", ((7,),)) @pytest.mark.parametrize("x2_shape", ((7,), (7, 3), (7, 7, 4), (7, 7, 7, 9))) -def test_matmul_dimone(ctx_factory, x1_shape, x2_shape): +def test_matmul_dimone(ctx_factory: cl.CtxFactory, x1_shape, x2_shape): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -162,14 +162,14 @@ def test_matmul_dimone(ctx_factory, x1_shape, x2_shape): @pytest.mark.parametrize("x1_shape", ((9, 7, 11, 4), (9, 7, 1, 4), (7, 4))) @pytest.mark.parametrize("x2_shape", ((7, 4, 5), (7, 4, 3))) -def test_matmul_higherdim(ctx_factory, x1_shape, x2_shape): +def test_matmul_higherdim(ctx_factory: cl.CtxFactory, x1_shape, x2_shape): ctx = ctx_factory() queue = cl.CommandQueue(ctx) _do_matmul(x1_shape, x2_shape, queue) -def test_data_wrapper(ctx_factory): +def test_data_wrapper(ctx_factory: cl.CtxFactory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -189,7 +189,7 @@ def test_data_wrapper(ctx_factory): assert (x_out == x_in).all() -def test_codegen_with_DictOfNamedArrays(ctx_factory): +def test_codegen_with_DictOfNamedArrays(ctx_factory: cl.CtxFactory): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -210,7 +210,7 @@ def test_codegen_with_DictOfNamedArrays(ctx_factory): @pytest.mark.parametrize("shift", (-1, 1, 0, -20, 20)) @pytest.mark.parametrize("axis", (0, 1)) -def test_roll(ctx_factory, shift, axis): +def test_roll(ctx_factory: cl.CtxFactory, shift, axis): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -226,7 +226,7 @@ def test_roll(ctx_factory, shift, axis): @pytest.mark.parametrize("axes", ( (), (0, 1), (1, 0), (0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0))) -def test_axis_permutation(ctx_factory, axes): +def test_axis_permutation(ctx_factory: cl.CtxFactory, axes): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -242,7 +242,7 @@ def test_axis_permutation(ctx_factory, axes): assert_allclose_to_numpy(pt.transpose(x, axes), queue) -def test_transpose(ctx_factory): +def test_transpose(ctx_factory: cl.CtxFactory): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -271,7 +271,7 @@ def wrapper(*args): "greater", "greater_equal", "logical_and", "logical_or")) @pytest.mark.parametrize("reverse", (False, True)) -def test_scalar_array_binary_arith(ctx_factory, which, reverse): +def test_scalar_array_binary_arith(ctx_factory: cl.CtxFactory, which, reverse): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal", @@ -325,7 +325,7 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse): "greater", "greater_equal", "logical_or", "logical_and")) @pytest.mark.parametrize("reverse", (False, True)) -def test_array_array_binary_arith(ctx_factory, which, reverse): +def test_array_array_binary_arith(ctx_factory: cl.CtxFactory, which, reverse): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal", @@ -377,7 +377,7 @@ def test_array_array_binary_arith(ctx_factory, which, reverse): @pytest.mark.parametrize("which", ("__and__", "__or__", "__xor__")) -def test_binary_logic(ctx_factory, which): +def test_binary_logic(ctx_factory: cl.CtxFactory, which): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -401,7 +401,7 @@ def test_binary_logic(ctx_factory, which): @pytest.mark.parametrize("which", ("neg", "pos")) -def test_unary_arith(ctx_factory, which): +def test_unary_arith(ctx_factory: cl.CtxFactory, which: str): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -426,6 +426,38 @@ def test_unary_arith(ctx_factory, which): assert np.array_equal(out, out_ref) +def test_astype(ctx_factory: cl.CtxFactory): + cl_ctx = ctx_factory() + queue = cl.CommandQueue(cl_ctx) + + for from_dtype in ARITH_DTYPES: + from_dtype = np.dtype(from_dtype) + x_orig = np.array( + [1, 2.3333333333333333333333, + 3.3333333333333333333333333j, 4, 5, + -234743, 23984723984720394872], + dtype=np.complex128, + ).astype(from_dtype) + + x_orig_sym = pt.make_data_wrapper(x_orig) + + for to_dtype in ARITH_DTYPES: + to_dtype = np.dtype(to_dtype) + + print(from_dtype, to_dtype) + x = x_orig.astype(to_dtype) + try: + x_sym = x_orig_sym.astype(to_dtype) + except NotImplementedError: + continue + + prog = pt.generate_loopy(x_sym) + + _, (x_sym_eval,) = prog(queue) + + assert np.array_equal(x, x_sym_eval) + + def generate_test_slices_for_dim(dim_bound): # Include scalars to test indexing. for i in range(dim_bound): @@ -441,7 +473,7 @@ def generate_test_slices(shape): @pytest.mark.parametrize("shape", [(3,), (2, 2), (1, 2, 1)]) -def test_slice(ctx_factory, shape): +def test_slice(ctx_factory: cl.CtxFactory, shape): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -469,7 +501,7 @@ def test_slice(ctx_factory, shape): @pytest.mark.parametrize("input_dims", (1, 2, 3)) -def test_stack(ctx_factory, input_dims): +def test_stack(ctx_factory: cl.CtxFactory, input_dims): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -487,7 +519,7 @@ def test_stack(ctx_factory, input_dims): assert_allclose_to_numpy(pt.stack((x, y), axis=axis), queue) -def test_concatenate(ctx_factory): +def test_concatenate(ctx_factory: cl.CtxFactory): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -511,7 +543,7 @@ def test_concatenate(ctx_factory): @pytest.mark.parametrize("newshape", [ *_SHAPES, (-1,), (-1, 6), (4, 9), (9, -1), (36, -1), 36]) @pytest.mark.parametrize("order", ["C", "F"]) -def test_reshape(ctx_factory, oldshape, newshape, order): +def test_reshape(ctx_factory: cl.CtxFactory, oldshape, newshape, order): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -539,7 +571,7 @@ def test_dict_of_named_array_codegen_avoids_recomputation(): assert ("y" in knl.id_to_insn["z_store"].read_dependency_names()) -def test_dict_to_loopy_kernel(ctx_factory): +def test_dict_to_loopy_kernel(ctx_factory: cl.CtxFactory): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -573,7 +605,7 @@ def test_only_deps_as_knl_args(): @pytest.mark.parametrize("function_name", ("abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh", "tanh", "exp", "log", "log10", "sqrt", "conj", "__abs__", "real", "imag")) -def test_math_functions(ctx_factory, dtype, function_name): +def test_math_functions(ctx_factory: cl.CtxFactory, dtype, function_name): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -603,7 +635,7 @@ def test_math_functions(ctx_factory, dtype, function_name): @pytest.mark.parametrize("dtype", (np.float32, np.float64, np.complex128)) @pytest.mark.parametrize("function_name", ("arctan2",)) -def test_binary_math_functions(ctx_factory, dtype, function_name): +def test_binary_math_functions(ctx_factory: cl.CtxFactory, dtype, function_name): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -628,7 +660,7 @@ def test_binary_math_functions(ctx_factory, dtype, function_name): @pytest.mark.parametrize("dtype", (np.int32, np.int64, np.float32, np.float64)) -def test_full_zeros_ones(ctx_factory, dtype): +def test_full_zeros_ones(ctx_factory: cl.CtxFactory, dtype): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -644,7 +676,7 @@ def test_full_zeros_ones(ctx_factory, dtype): assert (t == 2).all() -def test_passing_bound_arguments_raises(ctx_factory): +def test_passing_bound_arguments_raises(ctx_factory: cl.CtxFactory): queue = cl.CommandQueue(ctx_factory()) x = pt.make_data_wrapper(np.ones(10), tags=frozenset([pt.tags.PrefixNamed("x")])) @@ -663,7 +695,7 @@ def test_passing_bound_arguments_raises(ctx_factory): [(4, 1, 3), (1, 7, 1)], [(4, 1, 3), (1, p.Variable("n")+2, 1)], )) -def test_broadcasting(ctx_factory, shape1, shape2): +def test_broadcasting(ctx_factory: cl.CtxFactory, shape1, shape2): from numpy.random import default_rng from pymbolic.mapper.evaluator import evaluate @@ -690,7 +722,7 @@ def test_broadcasting(ctx_factory, shape1, shape2): @pytest.mark.parametrize("which", ("maximum", "minimum")) -def test_maximum_minimum(ctx_factory, which): +def test_maximum_minimum(ctx_factory: cl.CtxFactory, which): cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -715,7 +747,7 @@ def _get_rand_with_nans(shape): np.testing.assert_allclose(y, np_func(x1_in, x2_in), rtol=1e-6) -def test_call_loopy(ctx_factory): +def test_call_loopy(ctx_factory: cl.CtxFactory): from pytato.loopy import call_loopy cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) @@ -737,7 +769,7 @@ def test_call_loopy(ctx_factory): assert (z_out == 40*(x_in.sum(axis=1))).all() -def test_call_loopy_with_same_callee_names(ctx_factory): +def test_call_loopy_with_same_callee_names(ctx_factory: cl.CtxFactory): from pytato.loopy import call_loopy queue = cl.CommandQueue(ctx_factory()) @@ -767,7 +799,7 @@ def test_call_loopy_with_same_callee_names(ctx_factory): np.testing.assert_allclose(out_dict["nueve_u"], 9*u_in) -def test_exprs_with_named_arrays(ctx_factory): +def test_exprs_with_named_arrays(ctx_factory: cl.CtxFactory): queue = cl.CommandQueue(ctx_factory()) x_in = np.random.rand(10, 4) x = pt.make_data_wrapper(x_in) @@ -778,7 +810,7 @@ def test_exprs_with_named_arrays(ctx_factory): np.testing.assert_allclose(out, 42*x_in) -def test_call_loopy_with_parametric_sizes(ctx_factory): +def test_call_loopy_with_parametric_sizes(ctx_factory: cl.CtxFactory): x_in = np.random.rand(10, 4) @@ -804,7 +836,7 @@ def test_call_loopy_with_parametric_sizes(ctx_factory): np.testing.assert_allclose(z_out, 42*(x_in.sum(axis=1))) -def test_call_loopy_with_scalar_array_inputs(ctx_factory): +def test_call_loopy_with_scalar_array_inputs(ctx_factory: cl.CtxFactory): from numpy.random import default_rng import loopy as lp @@ -832,7 +864,7 @@ def test_call_loopy_with_scalar_array_inputs(ctx_factory): @pytest.mark.parametrize("axis", (None, 1, 0)) @pytest.mark.parametrize("redn", ("sum", "amax", "amin", "prod", "all", "any")) @pytest.mark.parametrize("shape", [(2, 2), (1, 2, 1), (3, 4, 5)]) -def test_reductions(ctx_factory, axis, redn, shape): +def test_reductions(ctx_factory: cl.CtxFactory, axis, redn, shape): queue = cl.CommandQueue(ctx_factory()) from numpy.random import default_rng @@ -889,7 +921,7 @@ def test_reductions(ctx_factory, axis, redn, shape): ("ij,ij->ij", # broadcasting [(1, 3), (3, 1)]), ])) -def test_einsum(ctx_factory, spec, argshapes): +def test_einsum(ctx_factory: cl.CtxFactory, spec, argshapes): ctx = ctx_factory() queue = cl.CommandQueue(ctx) @@ -904,7 +936,7 @@ def test_einsum(ctx_factory, spec, argshapes): np.testing.assert_allclose(np_out, pt_out) -def test_einsum_with_parameterized_shapes(ctx_factory): +def test_einsum_with_parameterized_shapes(ctx_factory: cl.CtxFactory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -934,7 +966,8 @@ def _get_x_shape(_m, n_): np.testing.assert_allclose(np_out, pt_out) -def test_arguments_passing_to_loopy_kernel_for_non_dependent_vars(ctx_factory): +def test_arguments_passing_to_loopy_kernel_for_non_dependent_vars( + ctx_factory: cl.CtxFactory): from numpy.random import default_rng ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -949,7 +982,7 @@ def test_arguments_passing_to_loopy_kernel_for_non_dependent_vars(ctx_factory): np.testing.assert_allclose(out, 0) -def test_call_loopy_shape_inference1(ctx_factory): +def test_call_loopy_shape_inference1(ctx_factory: cl.CtxFactory): from numpy.random import default_rng import loopy as lp @@ -981,7 +1014,7 @@ def test_call_loopy_shape_inference1(ctx_factory): + np.arange(3))) -def test_call_loopy_shape_inference2(ctx_factory): +def test_call_loopy_shape_inference2(ctx_factory: cl.CtxFactory): from numpy.random import default_rng import loopy as lp @@ -1021,7 +1054,7 @@ def test_call_loopy_shape_inference2(ctx_factory): @pytest.mark.parametrize("n", [4, 3, 5]) @pytest.mark.parametrize("m", [2, 7, None]) @pytest.mark.parametrize("k", [-2, -1, 0, 1, 2]) -def test_eye(ctx_factory, n, m, k): +def test_eye(ctx_factory: cl.CtxFactory, n, m, k): ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -1034,7 +1067,7 @@ def test_eye(ctx_factory, n, m, k): np.testing.assert_allclose(out, np_eye) -def test_arange(ctx_factory): +def test_arange(ctx_factory: cl.CtxFactory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -1096,7 +1129,7 @@ def test_arange(ctx_factory): @pytest.mark.parametrize("which,num_args", ([("maximum", 2), ("minimum", 2), ])) -def test_pt_ops_on_scalar_args_computed_eagerly(ctx_factory, which, num_args): +def test_pt_ops_on_scalar_args_computed_eagerly(which: str, num_args: int): from numpy.random import default_rng rng = default_rng() args = [rng.random() for _ in range(num_args)] @@ -1113,7 +1146,7 @@ def test_pt_ops_on_scalar_args_computed_eagerly(ctx_factory, which, num_args): ((10, 5, 2, 7), (3, 7, 4))])) @pytest.mark.parametrize("a_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("b_dtype", [np.float32, np.complex64]) -def test_dot(ctx_factory, a_shape, b_shape, a_dtype, b_dtype): +def test_dot(ctx_factory: cl.CtxFactory, a_shape, b_shape, a_dtype, b_dtype): from numpy.random import default_rng rng = default_rng() ctx = ctx_factory() @@ -1144,7 +1177,7 @@ def test_dot(ctx_factory, a_shape, b_shape, a_dtype, b_dtype): ((10, 4), (4, 10))])) @pytest.mark.parametrize("a_dtype", [np.float32, np.complex64]) @pytest.mark.parametrize("b_dtype", [np.float32, np.complex64]) -def test_vdot(ctx_factory, a_shape, b_shape, a_dtype, b_dtype): +def test_vdot(ctx_factory: cl.CtxFactory, a_shape, b_shape, a_dtype, b_dtype): from numpy.random import default_rng rng = default_rng() ctx = ctx_factory() @@ -1171,7 +1204,7 @@ def test_vdot(ctx_factory, a_shape, b_shape, a_dtype, b_dtype): np.testing.assert_allclose(np_result, pt_result, rtol=1e-6) -def test_reduction_adds_deps(ctx_factory): +def test_reduction_adds_deps(ctx_factory: cl.CtxFactory): from numpy.random import default_rng ctx = ctx_factory() @@ -1192,7 +1225,7 @@ def test_reduction_adds_deps(ctx_factory): out_dict["z"]) -def test_broadcast_to(ctx_factory): +def test_broadcast_to(ctx_factory: cl.CtxFactory): from numpy.random import default_rng ctx = ctx_factory() @@ -1220,7 +1253,7 @@ def test_broadcast_to(ctx_factory): x_brdcst) -def test_advanced_indexing_with_broadcasting(ctx_factory): +def test_advanced_indexing_with_broadcasting(ctx_factory: cl.CtxFactory): from numpy.random import default_rng ctx = ctx_factory() @@ -1256,7 +1289,7 @@ def test_advanced_indexing_with_broadcasting(ctx_factory): np.testing.assert_allclose(pt_out, x_in[idx1_in, ..., idx2_in]) -def test_advanced_indexing_fuzz(ctx_factory): +def test_advanced_indexing_fuzz(ctx_factory: cl.CtxFactory): from numpy.random import default_rng ctx = ctx_factory() @@ -1315,7 +1348,7 @@ def test_advanced_indexing_fuzz(ctx_factory): f" indices={pt_indices}")) -def test_reshape_on_scalars(ctx_factory): +def test_reshape_on_scalars(ctx_factory: cl.CtxFactory): # Reported by alexfikl # See https://github.com/inducer/pytato/issues/157 from numpy.random import default_rng @@ -1339,7 +1372,7 @@ def test_reshape_on_scalars(ctx_factory): assert_allclose_to_numpy(pt.reshape(x, (192, 168, 0, 1)), cq) -def test_materialize_reduces_flops(ctx_factory): +def test_materialize_reduces_flops(ctx_factory: cl.CtxFactory): x1 = pt.make_placeholder("x1", (10, 4), np.float64) x2 = pt.make_placeholder("x2", (10, 4), np.float64) x3 = pt.make_placeholder("x3", (10, 4), np.float64) @@ -1365,7 +1398,7 @@ def test_materialize_reduces_flops(ctx_factory): assert good_flops == (bad_flops - 80) -def test_named_temporaries(ctx_factory): +def test_named_temporaries(ctx_factory: cl.CtxFactory): x = pt.make_placeholder("x", (10, 4), np.float32) y = pt.make_placeholder("y", (10, 4), np.float32) tmp1 = 2 * x + 3 * y @@ -1394,7 +1427,7 @@ def mark_materialized_nodes_as_cse(ary: pt.Array | pt.AbstractResultWithNamedArr @pytest.mark.parametrize("shape", [(1, 3, 1), (1, 1), (2, 2, 3)]) -def test_squeeze(ctx_factory, shape): +def test_squeeze(ctx_factory: cl.CtxFactory, shape): ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -1408,7 +1441,7 @@ def test_squeeze(ctx_factory, shape): np.testing.assert_allclose(pt_result.shape, np_result) -def test_random_dag_against_numpy(ctx_factory): +def test_random_dag_against_numpy(ctx_factory: cl.CtxFactory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -1439,7 +1472,7 @@ def test_random_dag_against_numpy(ctx_factory): assert np.allclose(pt_result["result"], ref_result) -def test_assume_non_negative_indirect_address(ctx_factory): +def test_assume_non_negative_indirect_address(ctx_factory: cl.CtxFactory): from numpy.random import default_rng from pytato.scalar_expr import WalkMapper @@ -1546,7 +1579,7 @@ def test_array_tags_propagated_to_loopy(): .tags_of_type(BazTag)) -def test_scalars_are_typed(ctx_factory): +def test_scalars_are_typed(ctx_factory: cl.CtxFactory): # See https://github.com/inducer/pytato/issues/246 ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -1559,7 +1592,7 @@ def test_scalars_are_typed(ctx_factory): np.testing.assert_allclose(pt_out, np_out) -def test_regression_reduction_in_conditional(ctx_factory): +def test_regression_reduction_in_conditional(ctx_factory: cl.CtxFactory): # Regression test for https://github.com/inducer/pytato/pull/255 # which was ultimately caused by https://github.com/inducer/loopy/issues/533 # Reproducer from @@ -1596,7 +1629,7 @@ def get_np_input_args(): np.testing.assert_allclose(pt_result, np_result) -def test_zero_size_cl_array_dedup(ctx_factory): +def test_zero_size_cl_array_dedup(ctx_factory: cl.CtxFactory): # At pytato@0d8b909 this regression would fail as # 'deduplicate_data_wrappers' could not handle 0-long buffers import pyopencl.array as cla @@ -1679,7 +1712,7 @@ def test_deterministic_codegen(): # }}} -def test_no_computation_for_empty_arrays(ctx_factory): +def test_no_computation_for_empty_arrays(ctx_factory: cl.CtxFactory): # fails at faee1d2f4ffabed96e758a2c0fbd1b1fc1e78719 import pyopencl.array as cla @@ -1698,7 +1731,7 @@ def test_no_computation_for_empty_arrays(ctx_factory): assert not bprg.program.default_entrypoint.instructions -def test_expand_dims(ctx_factory): +def test_expand_dims(ctx_factory: cl.CtxFactory): from numpy.random import default_rng ntests = 50 @@ -1727,7 +1760,7 @@ def test_expand_dims(ctx_factory): np.testing.assert_allclose(np_output, pt_output) -def test_two_rolls(ctx_factory): +def test_two_rolls(ctx_factory: cl.CtxFactory): # see https://github.com/inducer/pytato/issues/341 cl_ctx = ctx_factory() cq = cl.CommandQueue(cl_ctx) @@ -1742,7 +1775,7 @@ def test_two_rolls(ctx_factory): np.testing.assert_allclose(np_out, pt_out) -def test_lp_substitution_result(ctx_factory): +def test_lp_substitution_result(ctx_factory: cl.CtxFactory): from pytato.target.loopy import ImplSubstitution cl_ctx = ctx_factory() @@ -1770,7 +1803,7 @@ def test_lp_substitution_result(ctx_factory): np.testing.assert_allclose(np.einsum("ij,j->i", 2*a_np, 3*x_np), pt_eval_out) -def test_rewrite_einsums_with_no_broadcasts(ctx_factory): +def test_rewrite_einsums_with_no_broadcasts(ctx_factory: cl.CtxFactory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -1789,7 +1822,7 @@ def test_rewrite_einsums_with_no_broadcasts(ctx_factory): np.testing.assert_allclose(ref_out, out) -def test_no_redundant_stores_with_impl_stored(ctx_factory): +def test_no_redundant_stores_with_impl_stored(ctx_factory: cl.CtxFactory): # See ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -1806,7 +1839,8 @@ def test_no_redundant_stores_with_impl_stored(ctx_factory): np.testing.assert_allclose(prg(cq, x=x_np)[1][0], 2*x_np) -def test_placeholders_do_not_diverge_after_removing_impl_stored(ctx_factory): +def test_placeholders_do_not_diverge_after_removing_impl_stored( + ctx_factory: cl.CtxFactory): # Note: An earlier attempt at fixing # would create multiple # instances of placeholders in the graph leading to incoherent codes. @@ -1847,7 +1881,7 @@ def _get_mask_array_idx(*idxs): ) -def test_pad(ctx_factory): +def test_pad(ctx_factory: cl.CtxFactory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -1912,7 +1946,7 @@ def test_pad(ctx_factory): np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) -def test_function_call(ctx_factory, visualize=False): +def test_function_call(ctx_factory: cl.CtxFactory, visualize=False): from functools import partial cl_ctx = ctx_factory() cq = cl.CommandQueue(cl_ctx) @@ -1962,7 +1996,7 @@ def build_expression(tracer): np.testing.assert_allclose(outputs[key], expected[key]) -def test_nested_function_calls(ctx_factory): +def test_nested_function_calls(ctx_factory: cl.CtxFactory): from functools import partial ctx = ctx_factory() @@ -2008,7 +2042,7 @@ def call_bar(tracer, x, y) -> pt.Array: np.testing.assert_allclose(result_out[k], expect_out[k]) -def test_pow_arg_casting(ctx_factory): +def test_pow_arg_casting(ctx_factory: cl.CtxFactory): # Check that pow() arguments are not typecast from int ctx = ctx_factory() cq = cl.CommandQueue(ctx) @@ -2067,7 +2101,7 @@ def test_pow_arg_casting(ctx_factory): (float, np.float32, np.float64) -def test_forcevalueargtag(ctx_factory): +def test_forcevalueargtag(ctx_factory: cl.CtxFactory): ctx = ctx_factory() cq = cl.CommandQueue(ctx) diff --git a/test/test_distributed.py b/test/test_distributed.py index 65214c4b0..7de1607ff 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -85,7 +85,7 @@ def test_distributed_execution_basic(): run_test_with_mpi(2, _do_test_distributed_execution_basic) -def _do_test_distributed_execution_basic(ctx_factory): +def _do_test_distributed_execution_basic(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error comm = MPI.COMM_WORLD @@ -252,7 +252,7 @@ class _RandomDAGTag: pass -def _do_test_distributed_execution_random_dag(ctx_factory): +def _do_test_distributed_execution_random_dag(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error comm = MPI.COMM_WORLD @@ -353,7 +353,7 @@ def gen_comm(rdagc): # {{{ test DAG with no comm nodes -def _test_dag_with_no_comm_nodes_inner(ctx_factory): +def _test_dag_with_no_comm_nodes_inner(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error from numpy.random import default_rng comm = MPI.COMM_WORLD @@ -390,7 +390,7 @@ def test_dag_with_no_comm_nodes(): # {{{ test DAG with duplicated output arrays -def _test_dag_with_duplicated_output_arrays_inner(ctx_factory): +def _test_dag_with_duplicated_output_arrays_inner(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error from numpy.random import default_rng comm = MPI.COMM_WORLD @@ -426,7 +426,7 @@ def test_dag_with_duplicated_output_arrays(): # {{{ test DAG with a receive as an output -def _test_dag_with_recv_as_output_inner(ctx_factory): +def _test_dag_with_recv_as_output_inner(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error from numpy.random import default_rng comm = MPI.COMM_WORLD @@ -473,7 +473,8 @@ def test_dag_with_recv_as_output(): # {{{ test DAG with a materialized array promoted to a part output -def _test_dag_with_materialized_array_promoted_to_part_output_inner(ctx_factory): +def _test_dag_with_materialized_array_promoted_to_part_output_inner( + ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error from numpy.random import default_rng comm = MPI.COMM_WORLD @@ -541,7 +542,7 @@ def test_dag_with_materialized_array_promoted_to_part_output(): # {{{ test DAG with multiple send nodes per sent array -def _test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory): +def _test_dag_with_multiple_send_nodes_per_sent_array_inner(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error from numpy.random import default_rng comm = MPI.COMM_WORLD @@ -595,7 +596,7 @@ def test_dag_with_multiple_send_nodes_per_sent_array(): # {{{ test DAG with periodic communication -def _test_dag_with_periodic_communication_inner(ctx_factory): +def _test_dag_with_periodic_communication_inner(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error from numpy.random import default_rng comm = MPI.COMM_WORLD @@ -656,7 +657,7 @@ def test_dag_with_periodic_communication(num_ranks): # {{{ test deterministic partitioning -def _gather_random_dist_partitions(ctx_factory): +def _gather_random_dist_partitions(ctx_factory: cl.CtxFactory): import mpi4py.MPI as MPI comm = MPI.COMM_WORLD @@ -709,7 +710,7 @@ def test_kaushik_mwe(): run_test_with_mpi(2, _do_test_kaushik_mwe) -def _do_test_kaushik_mwe(ctx_factory): +def _do_test_kaushik_mwe(ctx_factory: cl.CtxFactory): # from https://github.com/inducer/pytato/pull/393#issuecomment-1324642248 from mpi4py import MPI @@ -765,7 +766,7 @@ def test_verify_distributed_partition(): run_test_with_mpi(2, _do_verify_distributed_partition) -def _do_verify_distributed_partition(ctx_factory): +def _do_verify_distributed_partition(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error comm = MPI.COMM_WORLD from pytato.distributed.verify import ( @@ -874,7 +875,7 @@ class FooTag2: pass -def test_number_symbolic_tags_bare_classes(ctx_factory): +def test_number_symbolic_tags_bare_classes(ctx_factory: cl.CtxFactory): from mpi4py import MPI # pylint: disable=import-error comm = MPI.COMM_WORLD from pytato.distributed.nodes import make_distributed_recv, staple_distributed_send diff --git a/test/test_linalg.py b/test/test_linalg.py index 28d0585fe..a67e09430 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -27,7 +27,6 @@ """ import sys -from typing import cast import numpy as np @@ -91,7 +90,7 @@ def how_to_distribute( assert ( y_transformed == ( 7 * (A1 @ (2*x1-3*x2)) - + 8 * (cast("pt.Array", pt.sin(A2)) @ (2*x1-3*x2)))) + + 8 * (pt.sin(A2) @ (2*x1-3*x2)))) def test_apply_einsum_distributive_law_2():