Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 112 additions & 23 deletions proposal2a.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,36 @@ This can probably be just about any integral type.
```C
typedef enum
{
XXX_TYPE_F32,
XXX_TYPE_F64,
XXX_TYPE_C32,
XXX_TYPE_C64,
/* maybe these could be a bitfield, but are there enough bits??? */
/* Reserved for standard 0x0 - 0x100 (for example) */
XXX_TYPE_F32, // required
XXX_TYPE_F64, // required
XXX_TYPE_C32, // required
XXX_TYPE_C64, // required
XXX_TYPE_F16, // required?
XXX_TYPE_BF16, // required?
XXX_TYPE_F8,
XXX_TYPE_BF8,
XXX_TYPE_I32,
XXX_TYPE_U32,
XXX_TYPE_I16,
XXX_TYPE_U16,
XXX_TYPE_I8,
XXX_TYPE_U8,
/* Available for implementers 0x100 - 0x1000 (for example) */
...
} XXX_datatype;

typedef enum
{
/* Implementations may use more precise computational type */
/* Reserved for standard 0x2000 - 0x2100 (for example) */
XXX_TYPE_F32_F32_ACCUM_F32 = XXX_TYPE_F32,
XXX_TYPE_F64_F64_ACCUM_F64 = XXX_TYPE_F64,
...,
  XXX_TYPE_LOWER, /* narrowest of input precisions */ /* should this be part of the attrs? */ /* should there be a truly neutral default (maybe HW dependent)? */
XXX_TYPE_HIGHER, /* widest of input precisions */
/* Available for implementers 0x2100 - 0x3000 (for example) */
...
} XXX_comp_datatype;
```
Expand All @@ -32,18 +52,34 @@ Enumerations for the supported storage and computational datatypes. Not all comb
```C
typedef /* unspecified */ XXX_error; // Should be a trivial type, e.g. "int"

int XXX_error_check(XXX_error err); // return non-zero on error

const char* XXX_error_explain(XXX_error err);

void XXX_error_clear(XXX_error err);
/*
* Required errors:
* - Invalid values (negative lengths, same extent for shared dimension)
* - Null pointers (except 0-dimensional [or maybe 1+-dimensional is required?])
* - If D == C (or XXX_IN_PLACE), stride_D_XXX are ignored (can be NULL)
* - Invocation failure (generic failure)?
*
* Should some other information be available, e.g. out-of-memory so user could try again later.
*/

// The error explain function should not allocate the error string itself
// for security concerns.
// Adapted from the function MPI_Error_string
XXX_error XXX_error_explain(XXX_error err, char *error_string, int *error_size);

// Additionally one has to define as in MPI a MAX_ERROR_STRING
#define XXX_MAX_ERROR_STRING 512 /* implementation dependent */
```
Error handling --- implementation defined.

```C
typedef /* unspecified */ XXX_attr; // Requires initialization. E.g. "struct XXX_attr_internal*"
typedef int32_t XXX_key; // Some values should be reserved for standardization

/*
 * Potential keys:
 * - Execution plan (pointer to object)
 */

XXX_error XXX_attr_init(XXX_attr* attr);

XXX_error XXX_attr_destroy(XXX_attr* attr);
Expand All @@ -59,38 +95,91 @@ Implementation defined (and maybe some standard) attributes, loosely based on MP
```C
// Unary and binary element-wise operations (transpose, scale, norm, reduction, etc.) should also be defined!

// Element-wise ops on A, B, and AB are very important for machine learning.
// Can this functionality be required in the interface without requiring JIT????

// Compute D_{idx_D} = alpha * A_{idx_A} * B_{idx_B} + beta * C_{idx_C}
// Here, plan creation is a required part of the API

typedef /* unspecified */ XXX_plan; // probably pointer to struct

XXX_error
XXX_contract(const void* alpha,
XXX_contract_plan(
XXX_datatype type_alpha,
const void* A,
XXX_datatype type_A,
int nmode_A,
const XXX_extent* shape_A,
const XXX_stride* stride_A,
const XXX_index* idx_A,
const void* B,
XXX_datatype type_B,
int nmode_B,
const XXX_extent* shape_B,
const XXX_stride* stride_B,
const XXX_index* idx_B,
const void* beta,
XXX_datatype type_beta,
const void* C,
XXX_datatype type_C,
int nmode_C,
const XXX_extent* shape_C,
const XXX_stride* stride_C,
const XXX_index* idx_C,
const XXX_extent* shape_C,
const XXX_stride* stride_C,
const XXX_index* idx_C, // users should specify C twice for in-place
XXX_datatype type_D, // instead, could C or D be NULL?
const XXX_stride* stride_D, // if C == D, do we also need nmode_D, shape_D, etc.?
XXX_comp_datatype type_comp,
    XXX_plan* plan, /* out: created plan (pointer, matching XXX_attr_init's out-param style) — confirm intent */
XXX_attr attr);

XXX_error
XXX_contract_execute(
const void* alpha,
const void* A,
const void* B,
const void* beta,
const void* C,
void* D,
XXX_datatype type_D,
int nmode_D,
const XXX_extent* shape_D,
const XXX_stride* stride_D,
const XXX_index* idx_D,
    XXX_comp_datatype type_comp,
XXX_plan plan);

// Batched tensor contraction (TBD)

XXX_error
XXX_contract_batched(
int batch_size,
int nmode_M,
const XXX_extent* shape_M,
int nmode_N,
const XXX_extent* shape_N,
int nmode_K,
const XXX_extent* shape_K,
int nmode_L,
const XXX_extent* shape_L,
const void* alpha,
XXX_datatype type_alpha,
const void** A,
XXX_datatype type_A,
const XXX_stride* stride_A_M,
const XXX_stride* stride_A_K,
const XXX_stride* stride_A_L,
const void** B,
XXX_datatype type_B,
const XXX_stride* stride_B_K,
const XXX_stride* stride_B_N,
const XXX_stride* stride_B_L,
const void* beta,
XXX_datatype type_beta,
const void** C,
XXX_datatype type_C,
const XXX_stride* stride_C_M,
const XXX_stride* stride_C_N,
const XXX_stride* stride_C_L,
void** D, // users should specify C twice for in-place
XXX_datatype type_D, // instead, could C or D be NULL?
const XXX_stride* stride_D_M, // if C == D, do we also need nmode_D, shape_D, etc.?
const XXX_stride* stride_D_N, // maybe XXX_IN_PLACE tag for C == D?
const XXX_stride* stride_D_L,
XXX_comp_datatype type_comp,
XXX_attr attr);

/* See also cublasDgemmGroupedBatched for more complex batched interface */


```