Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion renderer/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
ZBuffer,
)

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]

RowIndices = Integer[Array, "row_batches row_batch_size"]
"""Indices of the rows in the buffers to be processed in this batch."""
Expand Down
3 changes: 2 additions & 1 deletion renderer/shader.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
Vec4f,
)

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]

ID: TypeAlias = IntV

Expand Down
3 changes: 2 additions & 1 deletion renderer/shaders/depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from ..shader import ID, PerVertex, Shader
from ..types import Vec4f

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]


class DepthExtraInput(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion renderer/shaders/gouraud.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from ..shader import ID, PerFragment, PerVertex, Shader
from ..types import BoolV, Colour, FloatV, LightSource, Vec2f, Vec3f, Vec4f

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]


class GouraudExtraInput(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion renderer/shaders/gouraud_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from ..shader import ID, MixerOutput, PerFragment, PerVertex, Shader
from ..types import BoolV, Colour, FloatV, LightSource, Texture, Vec2f, Vec3f, Vec4f

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]


class GouraudTextureExtraInput(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion renderer/shaders/phong.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from ..shader import ID, MixerOutput, PerFragment, PerVertex, Shader
from ..types import BoolV, Colour, LightSource, Texture, Vec2f, Vec3f, Vec4f

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]


class PhongTextureExtraInput(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion renderer/shaders/phong_darboux.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
Vec4f,
)

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]

Triangle3f: TypeAlias = Float[Array, "3 3"]
Triangle2f: TypeAlias = Float[Array, "3 2"]
Expand Down
3 changes: 2 additions & 1 deletion renderer/shaders/phong_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
Vec4f,
)

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]


class PhongReflectionTextureExtraInput(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion renderer/shaders/phong_reflection_shadow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
Vec4f,
)

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]


class PhongReflectionShadowTextureExtraInput(NamedTuple):
Expand Down
3 changes: 2 additions & 1 deletion renderer/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@
"Buffers",
]

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]

BoolV: TypeAlias = Bool[Array, ""]
"""JAX Array with single bool value.""" ""
Expand Down
3 changes: 2 additions & 1 deletion test_resources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from renderer import Tuple, TypeAlias
from renderer.types import FaceIndices, Normals, Texture, UVCoordinates, Vertices

jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]
if hasattr(jax.config, "jax_array"):
jax.config.update("jax_array", True) # pyright: ignore[reportUnknownMemberType]

T2f: TypeAlias = Tuple[float, float]
T3f: TypeAlias = Tuple[float, float, float]
Expand Down