From 320f30c45d8ffdb4d23edbd866fc32da455b9dce Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 29 Aug 2025 14:20:18 -0500 Subject: [PATCH] A few types in primitives --- pytential/symbolic/primitives.py | 49 +++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/pytential/symbolic/primitives.py b/pytential/symbolic/primitives.py index 6cf262656..734522ac2 100644 --- a/pytential/symbolic/primitives.py +++ b/pytential/symbolic/primitives.py @@ -38,7 +38,7 @@ from warnings import warn import numpy as np -from typing_extensions import deprecated, override, reveal_type +from typing_extensions import deprecated, override from pymbolic import Expression, ExpressionNode as ExpressionNodeBase, Variable from pymbolic.geometric_algebra import MultiVector, componentwise @@ -57,6 +57,7 @@ from pytools import P, obj_array from pytools.obj_array import ( ObjectArray, + ObjectArray1D, ObjectArray2D, ObjectArrayND, ShapeT, @@ -950,7 +951,6 @@ def shape_operator( # https://en.wikipedia.org/w/index.php?title=Differential_geometry_of_surfaces&oldid=833587563 (E, F), (_F, G) = first_fundamental_form(ambient_dim, dim, dofdesc) (e, f), (_f, g) = second_fundamental_form(ambient_dim, dim, dofdesc) - reveal_type(E) result = np.zeros((2, 2), dtype=object) result[0, 0] = e*G-f*F @@ -2248,7 +2248,8 @@ def Dp(kernel, *args, **kwargs): def tangential_onb( ambient_dim: int, dim: int | None = None, - dofdesc: DOFDescriptorLike = None): + dofdesc: DOFDescriptorLike = None + ) -> ObjectArray2D[ArithmeticExpression]: """Return a matrix of shape ``(ambient_dim, dim)`` with orthogonal columns spanning the tangential space of the surface of *dofdesc*. """ @@ -2260,39 +2261,51 @@ def tangential_onb( # {{{ Gram-Schmidt - orth_pd_mat = np.zeros_like(pd_mat) + orth_pd_mat = np.zeros_like(obj_array.to_numpy(pd_mat)) for k in range(pd_mat.shape[1]): avec = pd_mat[:, k] + q = avec for j in range(k): - q = q - np.dot(avec, orth_pd_mat[:, j])*orth_pd_mat[:, j] + q = q - (avec @ orth_pd_mat[:, j])*orth_pd_mat[:, j] q = cse(q, f"q{k}") - orth_pd_mat[:, k] = cse(q/sqrt(np.sum(q**2)), f"orth_pd_vec{k}_") + orth_pd_mat[:, k] = cse(q/sqrt(obj_array.sum(q**2)), f"orth_pd_vec{k}_") # }}} - return orth_pd_mat + return obj_array.from_numpy(orth_pd_mat) -def xyz_to_tangential(xyz_vec, dofdesc: DOFDescriptorLike = None): +def xyz_to_tangential( + xyz_vec: ObjectArray1D[ArithmeticExpression], + dofdesc: DOFDescriptorLike = None + ) -> ObjectArray1D[ArithmeticExpression]: ambient_dim = len(xyz_vec) tonb = tangential_onb(ambient_dim, dofdesc=dofdesc) return obj_array.new_1d([ - np.dot(tonb[:, i], xyz_vec) + tonb[:, i] @ xyz_vec for i in range(ambient_dim - 1) ]) -def tangential_to_xyz(tangential_vec, dofdesc: DOFDescriptorLike = None): +def tangential_to_xyz( + tangential_vec: ObjectArray1D[ArithmeticExpression], + dofdesc: DOFDescriptorLike = None + ) -> ObjectArray1D[ArithmeticExpression]: ambient_dim = len(tangential_vec) + 1 tonb = tangential_onb(ambient_dim, dofdesc=dofdesc) - return sum( - tonb[:, i] * tangential_vec[i] - for i in range(ambient_dim - 1)) + return cast( + "ObjectArray1D[ArithmeticExpression]", + sum( + tonb[:, i] * tangential_vec[i] + for i in range(ambient_dim - 1))) -def project_to_tangential(xyz_vec, dofdesc: DOFDescriptorLike = None): +def project_to_tangential( + xyz_vec: ObjectArray1D[ArithmeticExpression], + dofdesc: DOFDescriptorLike = None + ) -> ObjectArray1D[ArithmeticExpression]: return tangential_to_xyz( cse(xyz_to_tangential(xyz_vec, dofdesc)), dofdesc) @@ -2319,14 +2332,16 @@ def n_cross(vec, dofdesc: DOFDescriptorLike = None): return cross(normal(3, dofdesc=dofdesc).as_vector(), vec) -def div(vec): +def div(vec: ObjectArray1D[ArithmeticExpression]) -> ArithmeticExpression: ambient_dim = len(vec) return sum( dd_axis(iaxis, ambient_dim, vec[iaxis]) for iaxis in range(ambient_dim)) -def curl(vec): +def curl( + vec: ObjectArray1D[ArithmeticExpression] + ) -> ObjectArray1D[ArithmeticExpression]: from pytools import levi_civita return obj_array.new_1d([ @@ -2338,7 +2353,7 @@ def curl(vec): # }}} -def pretty(expr): +def pretty(expr: Operand) -> str: # Doesn't quite belong here, but this is exposed to the user as # "pytential.sym", so in here it goes.