Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ script: rake -f ci/Rakefile ci:run

matrix:
include:
- name: "Ruby 2.6.6 with MXNet 1.7.0 on Python 3.7.2"
env:
- ruby_version=2.6.6
- python_version=3.7.2
- mxnet_version=1.7.0
- name: "Ruby 2.6.0 with MXNet 1.3.1 on Python 3.7.2"
env:
- ruby_version=2.6.0
Expand Down
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# MXNet
[![Build Status](https://travis-ci.org/mrkn/mxnet.rb.svg?branch=master)](https://travis-ci.org/mrkn/mxnet.rb)

Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/mxnet`. To experiment with that code, run `bin/console` for an interactive prompt.

TODO: Delete this and the text above, and describe your gem
Welcome to the ruby mxnet bindings with access to core mxnet modules including NDArray and Gluon. The latest version tested is [1.7.0](https://mxnet.apache.org/versions/1.7.0/)

## Installation

Expand All @@ -22,7 +20,10 @@ Or install it yourself as:

## Usage

TODO: Write usage instructions here
To experiment with that code, run bin/console for an interactive prompt.

Make sure that the mxnet library files are available by including the .so in `LD_LIBRARY_PATH` environment variable.
`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path to mxnet library>`, alternatively put your Ruby code in the file `lib/mxnet` of this project.

## Development

Expand Down
4 changes: 2 additions & 2 deletions ci/Rakefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
namespace :ci do
def get_image_name
ruby_version = ENV['ruby_version'] || '2.5.1'
ruby_version = ENV['ruby_version'] || '2.6.6'
python_version = ENV['python_version'] || '3.7.0'
mxnet_version = ENV['mxnet_version'] || '1.2.1.post1'
mxnet_version = ENV['mxnet_version'] || '1.7.0'
return ['mrkn/mxnet-rb-ci', [ruby_version, python_version, mxnet_version].join('-')].join(':')
end

Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile.erb
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
FROM rubylang/ruby:<%= ruby_version %>-bionic

ARG PYTHON_VERSION=3.7.2
ARG MXNET_VERSION=1.3.1
ARG MXNET_VERSION=1.7.0

ENV LANG C.UTF-8
ENV DEBIAN_FRONTEND noninteractive
Expand Down
9 changes: 5 additions & 4 deletions docker/Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ namespace :docker do

def run_all(task)
%w[
2.6.6
2.6.0
2.5.3
2.4.5
Expand All @@ -52,9 +53,9 @@ namespace :docker do
end

task :build do
ruby_version = ENV['ruby_version'] || '2.6.0'
ruby_version = ENV['ruby_version'] || '2.6.6'
python_version = ENV['python_version'] || '3.7.2'
mxnet_version = ENV['mxnet_version'] || '1.3.1'
mxnet_version = ENV['mxnet_version'] || '1.7.0'
docker_build(ruby_version, python_version, mxnet_version)
end

Expand All @@ -65,9 +66,9 @@ namespace :docker do
end

task :push do
ruby_version = ENV['ruby_version'] || '2.6.0'
ruby_version = ENV['ruby_version'] || '2.6.6'
python_version = ENV['python_version'] || '3.7.2'
mxnet_version = ENV['mxnet_version'] || '1.3.1'
mxnet_version = ENV['mxnet_version'] || '1.7.0'
docker_push(ruby_version, python_version, mxnet_version)
end

Expand Down
4 changes: 2 additions & 2 deletions docker/install_ruby.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

set -ex

RUBY_VERSION=${RUBY_VERSION-2.5.1}
RUBY_VERSION=${RUBY_VERSION-2.6.6}
RUBY_MAJOR=$(echo -n $RUBY_VERSION | sed -E 's/\.[0-9]+(-.*)?$//g')
RUBYGEMS_VERSION=${RUBYGEMS_VERSION-2.7.7}
BUNDLER_VERSION=${BUNDLER_VERSION-1.16.4}
BUNDLER_VERSION=${BUNDLER_VERSION-2.1.2}

case $RUBY_VERSION in
2.6.0)
Expand Down
9 changes: 9 additions & 0 deletions ext/mxnet/libmxnet.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ init_api_table(VALUE handle)
INIT_API_TABLE_ENTRY(MXNDArrayGetContext);
INIT_API_TABLE_ENTRY(MXNDArrayGetShape);
INIT_API_TABLE_ENTRY(MXNDArrayGetDType);
INIT_API_TABLE_ENTRY(MXNDArrayGetStorageType);
INIT_API_TABLE_ENTRY(MXNDArraySyncCopyFromCPU);
INIT_API_TABLE_ENTRY(MXNDArraySyncCopyToCPU);
INIT_API_TABLE_ENTRY(MXNDArrayAt);
Expand Down Expand Up @@ -132,6 +133,8 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals,
ndargs = rb_convert_type(ndargs, T_ARRAY, "Array", "to_ary");
keys = rb_convert_type(keys, T_ARRAY, "Array", "to_ary");
vals = rb_convert_type(vals, T_ARRAY, "Array", "to_ary");


if (!NIL_P(out) && !RTEST(rb_obj_is_kind_of(out, mxnet_cNDArray))) {
out = rb_convert_type(out, T_ARRAY, "Array", "to_ary");
}
Expand All @@ -145,8 +148,10 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals,

num_params = (int)RARRAY_LEN(keys);
keys_str = rb_str_tmp_new(sizeof(char const *)*num_params);
rb_gc_mark(keys_str);
params_keys = (char const **)RSTRING_PTR(keys_str);
vals_str = rb_str_tmp_new(sizeof(char const *)*num_params);
rb_gc_mark(vals_str);
params_vals = (char const **)RSTRING_PTR(vals_str);
for (i = 0; i < num_params; ++i) {
VALUE key, val;
Expand All @@ -155,6 +160,8 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals,
if (RB_TYPE_P(key, T_SYMBOL)) {
key = rb_sym_to_s(key);
}


params_keys[i] = StringValueCStr(key);

val = rb_String(RARRAY_AREF(vals, i));
Expand Down Expand Up @@ -237,8 +244,10 @@ symbol_creator(VALUE mod, VALUE handle, VALUE args, VALUE kwargs, VALUE keys, VA

num_params = (int)RARRAY_LEN(keys);
keys_str = rb_str_tmp_new(sizeof(char const **)*num_params);
rb_gc_mark(keys_str);
params_keys = (char const **)RSTRING_PTR(keys_str);
vals_str = rb_str_tmp_new(sizeof(char const **)*num_params);
rb_gc_mark(vals_str);
params_vals = (char const **)RSTRING_PTR(vals_str);
for (i = 0; i < num_params; ++i) {
VALUE key, val;
Expand Down
13 changes: 13 additions & 0 deletions ext/mxnet/mxnet_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,17 @@ enum DTypeID {
NUMBER_OF_DTYPE_IDS
};

enum StorageTypeID {
// dense
kDefaultStorage = 0,
// row sparse
kRowSparseStorage = 1,
// csr
kCSRStorage = 2,
NUMBER_OF_STORAGE_TYPE_IDS
};


struct mxnet_api_table {
const char * (* MXGetLastError)();

Expand Down Expand Up @@ -307,6 +318,7 @@ struct mxnet_api_table {
int *num_outputs,
NDArrayHandle **outputs,
int **out_stypes);
int (* MXNDArrayGetStorageType)(NDArrayHandle handle, int *out_storage_type);
};

struct mxnet_api_table *mxnet_get_api_table(void);
Expand All @@ -319,6 +331,7 @@ void *mxnet_get_handle(VALUE obj);
void mxnet_set_handle(VALUE obj, VALUE handle_v);

VALUE mxnet_dtype_id2name(int dtype_id);
VALUE mxnet_storage_type_id2name(int stype_id);
int mxnet_dtype_name2id(VALUE dtype_name);
VALUE mxnet_dtype_name(VALUE id_or_name);

Expand Down
144 changes: 143 additions & 1 deletion ext/mxnet/ndarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,108 @@ dtype_m_available_p(VALUE mod, VALUE dtype)
return mxnet_dtype_is_available(dtype) ? Qtrue : Qfalse;
}


static size_t storage_type_sizes[NUMBER_OF_STORAGE_TYPE_IDS];
static ID storage_type_name_ids[NUMBER_OF_STORAGE_TYPE_IDS];

VALUE
mxnet_storage_type_id2name(int stype_id)
{
if (0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS) {
return ID2SYM(storage_type_name_ids[stype_id]);
}

return Qnil;
}

static VALUE
storage_type_m_id2name(VALUE mod, VALUE stype_id_v)
{
int stype_id = NUM2INT(stype_id_v);
return mxnet_storage_type_id2name(stype_id);
}

int
mxnet_storage_type_name2id(VALUE stype_name)
{
ID stype_name_id;
int i;

if (!RB_TYPE_P(stype_name, T_SYMBOL)) {
stype_name = rb_to_symbol(StringValue(stype_name));
}
stype_name_id = SYM2ID(stype_name);

for (i = 0; i < NUMBER_OF_STORAGE_TYPE_IDS; ++i) {
if (storage_type_name_ids[i] == stype_name_id) {
return i;
}
}

return -1;
}

static VALUE
storage_type_m_name2id(VALUE mod, VALUE stype_name)
{
int stype_id;
stype_id = mxnet_storage_type_name2id(stype_name);
return stype_id == -1 ? Qnil : INT2NUM(stype_id);
}

VALUE
mxnet_storage_type_name(VALUE id_or_name)
{
int stype_id;

if (RB_INTEGER_TYPE_P(id_or_name)) {
stype_id = NUM2INT(id_or_name);
}
else {
stype_id = mxnet_storage_type_name2id(id_or_name);
}
if (0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS)
return mxnet_storage_type_id2name(stype_id);

return Qnil;
}

static VALUE
storage_type_m_name(VALUE mod, VALUE id_or_name)
{
return mxnet_storage_type_name(id_or_name);
}

int
mxnet_storage_type_is_available(VALUE stype)
{
int stype_id;
if (RB_INTEGER_TYPE_P(stype)) {
stype_id = NUM2INT(stype);
}
else {
if (!RB_TYPE_P(stype, T_SYMBOL)) {
VALUE str = rb_check_convert_type(stype, T_STRING, "String", "to_str");
if (NIL_P(str)) {
rb_raise(rb_eTypeError, "Invalid type for storage type (%"PRIsVALUE")", stype);
}
stype = str;
}
stype_id = mxnet_storage_type_name2id(stype);
}
return 0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS;
}

static VALUE
storage_type_m_available_p(VALUE mod, VALUE stype)
{
return mxnet_storage_type_is_available(stype) ? Qtrue : Qfalse;
}





static void
ndarray_free(void *ptr)
{
Expand Down Expand Up @@ -426,6 +528,24 @@ ndarray_get_dtype(VALUE obj)
return mxnet_dtype_id2name(dtype_id);
}

static int
ndarray_get_storage_type_id(VALUE obj)
{
NDArrayHandle handle;
int stype_id;

handle = mxnet_ndarray_get_handle(obj);
CHECK_CALL(MXNET_API(MXNDArrayGetStorageType)(handle, &stype_id));
return stype_id;
}

static VALUE
ndarray_get_storage_type(VALUE obj)
{
int stype_id = ndarray_get_storage_type_id(obj);
return mxnet_storage_type_id2name(stype_id);
}

VALUE
mxnet_ndarray_get_shape(VALUE obj)
{
Expand Down Expand Up @@ -741,7 +861,7 @@ ndarray_wait_to_read(VALUE obj)
void
mxnet_init_ndarray(void)
{
VALUE cNDArray, mDType;
VALUE cNDArray, mDType, mStorageType;

cNDArray = rb_const_get_at(mxnet_mMXNet, rb_intern("NDArray"));

Expand All @@ -754,6 +874,7 @@ mxnet_init_ndarray(void)
/* TODO: rb_define_singleton_method(cNDArray, "load_from_buffer", ndarray_s_load_from_buffer, 1); */

rb_define_method(cNDArray, "dtype", ndarray_get_dtype, 0);
rb_define_method(cNDArray, "stype", ndarray_get_storage_type, 0);
rb_define_method(cNDArray, "shape", mxnet_ndarray_get_shape, 0);
rb_define_method(cNDArray, "reshape", ndarray_reshape, 1);
rb_define_method(cNDArray, "grad", ndarray_grad, 0);
Expand Down Expand Up @@ -790,4 +911,25 @@ mxnet_init_ndarray(void)
INIT_DTYPE(kInt64, int64_t, "int64");

#undef INIT_DTYPE


mStorageType = rb_define_module_under(mxnet_mMXNet, "StorageType");

rb_define_module_function(mStorageType, "id2name", storage_type_m_id2name, 1);
rb_define_module_function(mStorageType, "name2id", storage_type_m_name2id, 1);
rb_define_module_function(mStorageType, "name", storage_type_m_name, 1);
rb_define_module_function(mStorageType, "available?", storage_type_m_available_p, 1);



#define INIT_STYPE(id, name) do { \
storage_type_name_ids[id] = rb_intern(name); \
} while (0)

INIT_STYPE(kDefaultStorage, "default");
INIT_STYPE(kRowSparseStorage, "row_sparse");
INIT_STYPE(kCSRStorage, "csr");

#undef INIT_STYPE

}
Loading