Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 44 additions & 9 deletions dpctl/tensor/_dlpack.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ from .._backend cimport (
DPCTLSyclDeviceRef,
DPCTLSyclUSMRef,
)
from ._usmarray cimport USM_ARRAY_WRITABLE, usm_ndarray
from ._usmarray cimport (
USM_ARRAY_C_CONTIGUOUS,
USM_ARRAY_F_CONTIGUOUS,
USM_ARRAY_WRITABLE,
usm_ndarray,
)

import ctypes

Expand Down Expand Up @@ -266,6 +271,7 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
cdef int64_t *shape_strides_ptr = NULL
cdef int i = 0
cdef int device_id = -1
cdef int flags = 0
cdef Py_ssize_t element_offset = 0
cdef Py_ssize_t byte_offset = 0
cdef Py_ssize_t si = 1
Expand All @@ -291,14 +297,29 @@ cpdef to_dlpack_capsule(usm_ndarray usm_ary):
for i in range(nd):
shape_strides_ptr[i] = shape_ptr[i]
strides_ptr = usm_ary.get_strides()
flags = usm_ary.flags_
if strides_ptr:
for i in range(nd):
shape_strides_ptr[nd + i] = strides_ptr[i]
else:
si = 1
for i in range(0, nd):
shape_strides_ptr[nd + i] = si
si = si * shape_ptr[i]
if flags & USM_ARRAY_C_CONTIGUOUS:
si = 1
for i in range(nd - 1, -1, -1):
shape_strides_ptr[nd + i] = si
si = si * shape_ptr[i]
elif flags & USM_ARRAY_F_CONTIGUOUS:
si = 1
for i in range(0, nd):
shape_strides_ptr[nd + i] = si
si = si * shape_ptr[i]
else:
stdlib.free(shape_strides_ptr)
stdlib.free(dlm_tensor)
raise BufferError(
"to_dlpack_capsule: Invalid array encountered "
"when building strides"
)

strides_ptr = <Py_ssize_t *>&shape_strides_ptr[nd]

ary_dt = usm_ary.dtype
Expand Down Expand Up @@ -409,10 +430,24 @@ cpdef to_dlpack_versioned_capsule(usm_ndarray usm_ary, bint copied):
for i in range(nd):
shape_strides_ptr[nd + i] = strides_ptr[i]
else:
si = 1
for i in range(0, nd):
shape_strides_ptr[nd + i] = si
si = si * shape_ptr[i]
if flags & USM_ARRAY_C_CONTIGUOUS:
si = 1
for i in range(nd - 1, -1, -1):
shape_strides_ptr[nd + i] = si
si = si * shape_ptr[i]
elif flags & USM_ARRAY_F_CONTIGUOUS:
si = 1
for i in range(0, nd):
shape_strides_ptr[nd + i] = si
si = si * shape_ptr[i]
else:
stdlib.free(shape_strides_ptr)
stdlib.free(dlmv_tensor)
raise BufferError(
"to_dlpack_versioned_capsule: Invalid array encountered "
"when building strides"
)

strides_ptr = <Py_ssize_t *>&shape_strides_ptr[nd]

# this can all be a function for building the dl_tensor
Expand Down
39 changes: 39 additions & 0 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,45 @@ def test_dlpack_capsule_readonly_array_to_kdlcpu():
assert not y1.flags["W"]


def test_to_dlpack_capsule_c_and_f_contig():
try:
x = dpt.asarray(np.random.rand(2, 3))
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

cap = _dlp.to_dlpack_capsule(x)
y = _dlp.from_dlpack_capsule(cap)
assert np.allclose(dpt.asnumpy(x), dpt.asnumpy(y))
assert x.strides == y.strides

x_f = x.T
cap = _dlp.to_dlpack_capsule(x_f)
yf = _dlp.from_dlpack_capsule(cap)
assert np.allclose(dpt.asnumpy(x_f), dpt.asnumpy(yf))
assert x_f.strides == yf.strides
del cap


def test_to_dlpack_versioned_capsule_c_and_f_contig():
try:
x = dpt.asarray(np.random.rand(2, 3))
max_supported_ver = _dlp.get_build_dlpack_version()
except dpctl.SyclDeviceCreationError:
pytest.skip("No default device available")

cap = x.__dlpack__(max_version=max_supported_ver)
y = _dlp.from_dlpack_capsule(cap)
assert np.allclose(dpt.asnumpy(x), dpt.asnumpy(y))
assert x.strides == y.strides

x_f = x.T
cap = x_f.__dlpack__(max_version=max_supported_ver)
yf = _dlp.from_dlpack_capsule(cap)
assert np.allclose(dpt.asnumpy(x_f), dpt.asnumpy(yf))
assert x_f.strides == yf.strides
del cap


def test_used_dlpack_capsule_from_numpy():
get_queue_or_skip()

Expand Down
Loading