diff --git a/.travis.yml b/.travis.yml index d6cd3f6..922374a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,6 +13,11 @@ script: rake -f ci/Rakefile ci:run matrix: include: + - name: "Ruby 2.6.6 with MXNet 1.7.0 on Python 3.7.2" + env: + - ruby_version=2.6.6 + - python_version=3.7.2 + - mxnet_version=1.7.0 - name: "Ruby 2.6.0 with MXNet 1.3.1 on Python 3.7.2" env: - ruby_version=2.6.0 diff --git a/README.md b/README.md index 1509613..a4001a2 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ -# MXNet +[![Build Status](https://travis-ci.org/mrkn/mxnet.rb.svg?branch=master)](https://travis-ci.org/mrkn/mxnet.rb) -Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/mxnet`. To experiment with that code, run `bin/console` for an interactive prompt. - -TODO: Delete this and the text above, and describe your gem +Welcome to the ruby mxnet bindings with access to core mxnet modules including NDArray and Gluon. The latest version tested is [1.7.0](https://mxnet.apache.org/versions/1.7.0/) ## Installation @@ -22,7 +20,10 @@ Or install it yourself as: ## Usage -TODO: Write usage instructions here +To experiment with that code, run bin/console for an interactive prompt. + +Make sure that the mxnet library files are available by including the .so in `LD_LIBRARY_PATH` environment variable. +`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`, alternatively put your Ruby code in the file `lib/mxnet` of this project. ## Development diff --git a/ci/Rakefile b/ci/Rakefile index 7ae4faa..607d02e 100644 --- a/ci/Rakefile +++ b/ci/Rakefile @@ -1,8 +1,8 @@ namespace :ci do def get_image_name - ruby_version = ENV['ruby_version'] || '2.5.1' + ruby_version = ENV['ruby_version'] || '2.6.6' python_version = ENV['python_version'] || '3.7.0' - mxnet_version = ENV['mxnet_version'] || '1.2.1.post1' + mxnet_version = ENV['mxnet_version'] || '1.7.0' return ['mrkn/mxnet-rb-ci', [ruby_version, python_version, mxnet_version].join('-')].join(':') end diff --git a/docker/Dockerfile.erb b/docker/Dockerfile.erb index 66493f2..f22061d 100644 --- a/docker/Dockerfile.erb +++ b/docker/Dockerfile.erb @@ -1,7 +1,7 @@ FROM rubylang/ruby:<%= ruby_version %>-bionic ARG PYTHON_VERSION=3.7.2 -ARG MXNET_VERSION=1.3.1 +ARG MXNET_VERSION=1.7.0 ENV LANG C.UTF-8 ENV DEBIAN_FRONTEND noninteractive diff --git a/docker/Rakefile b/docker/Rakefile index 15d64df..bc43ce0 100644 --- a/docker/Rakefile +++ b/docker/Rakefile @@ -40,6 +40,7 @@ namespace :docker do def run_all(task) %w[ + 2.6.6 2.6.0 2.5.3 2.4.5 @@ -52,9 +53,9 @@ namespace :docker do end task :build do - ruby_version = ENV['ruby_version'] || '2.6.0' + ruby_version = ENV['ruby_version'] || '2.6.6' python_version = ENV['python_version'] || '3.7.2' - mxnet_version = ENV['mxnet_version'] || '1.3.1' + mxnet_version = ENV['mxnet_version'] || '1.7.0' docker_build(ruby_version, python_version, mxnet_version) end @@ -65,9 +66,9 @@ namespace :docker do end task :push do - ruby_version = ENV['ruby_version'] || '2.6.0' + ruby_version = ENV['ruby_version'] || '2.6.6' python_version = ENV['python_version'] || '3.7.2' - mxnet_version = ENV['mxnet_version'] || '1.3.1' + mxnet_version = ENV['mxnet_version'] || '1.7.0' docker_push(ruby_version, python_version, mxnet_version) end diff --git a/docker/install_ruby.sh b/docker/install_ruby.sh index 8b2fb67..ff09113 100755 --- a/docker/install_ruby.sh +++ b/docker/install_ruby.sh @@ -2,10 +2,10 @@ set -ex -RUBY_VERSION=${RUBY_VERSION-2.5.1} +RUBY_VERSION=${RUBY_VERSION-2.6.6} RUBY_MAJOR=$(echo -n $RUBY_VERSION | sed -E 's/\.[0-9]+(-.*)?$//g') RUBYGEMS_VERSION=${RUBYGEMS_VERSION-2.7.7} -BUNDLER_VERSION=${BUNDLER_VERSION-1.16.4} +BUNDLER_VERSION=${BUNDLER_VERSION-2.1.2} case $RUBY_VERSION in 2.6.0) diff --git a/ext/mxnet/libmxnet.c b/ext/mxnet/libmxnet.c index d0e9101..c92ac7e 100644 --- a/ext/mxnet/libmxnet.c +++ b/ext/mxnet/libmxnet.c @@ -65,6 +65,7 @@ init_api_table(VALUE handle) INIT_API_TABLE_ENTRY(MXNDArrayGetContext); INIT_API_TABLE_ENTRY(MXNDArrayGetShape); INIT_API_TABLE_ENTRY(MXNDArrayGetDType); + INIT_API_TABLE_ENTRY(MXNDArrayGetStorageType); INIT_API_TABLE_ENTRY(MXNDArraySyncCopyFromCPU); INIT_API_TABLE_ENTRY(MXNDArraySyncCopyToCPU); INIT_API_TABLE_ENTRY(MXNDArrayAt); @@ -132,6 +133,8 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals, ndargs = rb_convert_type(ndargs, T_ARRAY, "Array", "to_ary"); keys = rb_convert_type(keys, T_ARRAY, "Array", "to_ary"); vals = rb_convert_type(vals, T_ARRAY, "Array", "to_ary"); + + if (!NIL_P(out) && !RTEST(rb_obj_is_kind_of(out, mxnet_cNDArray))) { out = rb_convert_type(out, T_ARRAY, "Array", "to_ary"); } @@ -145,8 +148,10 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals, num_params = (int)RARRAY_LEN(keys); keys_str = rb_str_tmp_new(sizeof(char const *)*num_params); + rb_gc_mark(keys_str); params_keys = (char const **)RSTRING_PTR(keys_str); vals_str = rb_str_tmp_new(sizeof(char const *)*num_params); + rb_gc_mark(vals_str); params_vals = (char const **)RSTRING_PTR(vals_str); for (i = 0; i < num_params; ++i) { VALUE key, val; @@ -155,6 +160,8 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals, if (RB_TYPE_P(key, T_SYMBOL)) { key = rb_sym_to_s(key); } + + params_keys[i] = StringValueCStr(key); val = rb_String(RARRAY_AREF(vals, i)); @@ -237,8 +244,10 @@ symbol_creator(VALUE mod, VALUE handle, VALUE args, VALUE kwargs, VALUE keys, VA num_params = (int)RARRAY_LEN(keys); keys_str = rb_str_tmp_new(sizeof(char const **)*num_params); + rb_gc_mark(keys_str); params_keys = (char const **)RSTRING_PTR(keys_str); vals_str = rb_str_tmp_new(sizeof(char const **)*num_params); + rb_gc_mark(vals_str); params_vals = (char const **)RSTRING_PTR(vals_str); for (i = 0; i < num_params; ++i) { VALUE key, val; diff --git a/ext/mxnet/mxnet_internal.h b/ext/mxnet/mxnet_internal.h index 8f726ae..7205a7d 100644 --- a/ext/mxnet/mxnet_internal.h +++ b/ext/mxnet/mxnet_internal.h @@ -94,6 +94,17 @@ enum DTypeID { NUMBER_OF_DTYPE_IDS }; +enum StorageTypeID { + // dense + kDefaultStorage = 0, + // row sparse + kRowSparseStorage = 1, + // csr + kCSRStorage = 2, + NUMBER_OF_STORAGE_TYPE_IDS +}; + + struct mxnet_api_table { const char * (* MXGetLastError)(); @@ -307,6 +318,7 @@ struct mxnet_api_table { int *num_outputs, NDArrayHandle **outputs, int **out_stypes); + int (* MXNDArrayGetStorageType)(NDArrayHandle handle, int *out_storage_type); }; struct mxnet_api_table *mxnet_get_api_table(void); @@ -319,6 +331,7 @@ void *mxnet_get_handle(VALUE obj); void mxnet_set_handle(VALUE obj, VALUE handle_v); VALUE mxnet_dtype_id2name(int dtype_id); +VALUE mxnet_storage_type_id2name(int stype_id); int mxnet_dtype_name2id(VALUE dtype_name); VALUE mxnet_dtype_name(VALUE id_or_name); diff --git a/ext/mxnet/ndarray.c b/ext/mxnet/ndarray.c index 216cff6..b65e786 100644 --- a/ext/mxnet/ndarray.c +++ b/ext/mxnet/ndarray.c @@ -99,6 +99,108 @@ dtype_m_available_p(VALUE mod, VALUE dtype) return mxnet_dtype_is_available(dtype) ? Qtrue : Qfalse; } + +static size_t storage_type_sizes[NUMBER_OF_STORAGE_TYPE_IDS]; +static ID storage_type_name_ids[NUMBER_OF_STORAGE_TYPE_IDS]; + +VALUE +mxnet_storage_type_id2name(int stype_id) +{ + if (0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS) { + return ID2SYM(storage_type_name_ids[stype_id]); + } + + return Qnil; +} + +static VALUE +storage_type_m_id2name(VALUE mod, VALUE stype_id_v) +{ + int stype_id = NUM2INT(stype_id_v); + return mxnet_storage_type_id2name(stype_id); +} + +int +mxnet_storage_type_name2id(VALUE stype_name) +{ + ID stype_name_id; + int i; + + if (!RB_TYPE_P(stype_name, T_SYMBOL)) { + stype_name = rb_to_symbol(StringValue(stype_name)); + } + stype_name_id = SYM2ID(stype_name); + + for (i = 0; i < NUMBER_OF_STORAGE_TYPE_IDS; ++i) { + if (storage_type_name_ids[i] == stype_name_id) { + return i; + } + } + + return -1; +} + +static VALUE +storage_type_m_name2id(VALUE mod, VALUE stype_name) +{ + int stype_id; + stype_id = mxnet_storage_type_name2id(stype_name); + return stype_id == -1 ? Qnil : INT2NUM(stype_id); +} + +VALUE +mxnet_storage_type_name(VALUE id_or_name) +{ + int stype_id; + + if (RB_INTEGER_TYPE_P(id_or_name)) { + stype_id = NUM2INT(id_or_name); + } + else { + stype_id = mxnet_storage_type_name2id(id_or_name); + } + if (0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS) + return mxnet_storage_type_id2name(stype_id); + + return Qnil; +} + +static VALUE +storage_type_m_name(VALUE mod, VALUE id_or_name) +{ + return mxnet_storage_type_name(id_or_name); +} + +int +mxnet_storage_type_is_available(VALUE stype) +{ + int stype_id; + if (RB_INTEGER_TYPE_P(stype)) { + stype_id = NUM2INT(stype); + } + else { + if (!RB_TYPE_P(stype, T_SYMBOL)) { + VALUE str = rb_check_convert_type(stype, T_STRING, "String", "to_str"); + if (NIL_P(str)) { + rb_raise(rb_eTypeError, "Invalid type for storage type (%"PRIsVALUE")", stype); + } + stype = str; + } + stype_id = mxnet_storage_type_name2id(stype); + } + return 0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS; +} + +static VALUE +storage_type_m_available_p(VALUE mod, VALUE stype) +{ + return mxnet_storage_type_is_available(stype) ? Qtrue : Qfalse; +} + + + + + static void ndarray_free(void *ptr) { @@ -426,6 +528,24 @@ ndarray_get_dtype(VALUE obj) return mxnet_dtype_id2name(dtype_id); } +static int +ndarray_get_storage_type_id(VALUE obj) +{ + NDArrayHandle handle; + int stype_id; + + handle = mxnet_ndarray_get_handle(obj); + CHECK_CALL(MXNET_API(MXNDArrayGetStorageType)(handle, &stype_id)); + return stype_id; +} + +static VALUE +ndarray_get_storage_type(VALUE obj) +{ + int stype_id = ndarray_get_storage_type_id(obj); + return mxnet_storage_type_id2name(stype_id); +} + VALUE mxnet_ndarray_get_shape(VALUE obj) { @@ -741,7 +861,7 @@ ndarray_wait_to_read(VALUE obj) void mxnet_init_ndarray(void) { - VALUE cNDArray, mDType; + VALUE cNDArray, mDType, mStorageType; cNDArray = rb_const_get_at(mxnet_mMXNet, rb_intern("NDArray")); @@ -754,6 +874,7 @@ mxnet_init_ndarray(void) /* TODO: rb_define_singleton_method(cNDArray, "load_from_buffer", ndarray_s_load_from_buffer, 1); */ rb_define_method(cNDArray, "dtype", ndarray_get_dtype, 0); + rb_define_method(cNDArray, "stype", ndarray_get_storage_type, 0); rb_define_method(cNDArray, "shape", mxnet_ndarray_get_shape, 0); rb_define_method(cNDArray, "reshape", ndarray_reshape, 1); rb_define_method(cNDArray, "grad", ndarray_grad, 0); @@ -790,4 +911,25 @@ mxnet_init_ndarray(void) INIT_DTYPE(kInt64, int64_t, "int64"); #undef INIT_DTYPE + + + mStorageType = rb_define_module_under(mxnet_mMXNet, "StorageType"); + + rb_define_module_function(mStorageType, "id2name", storage_type_m_id2name, 1); + rb_define_module_function(mStorageType, "name2id", storage_type_m_name2id, 1); + rb_define_module_function(mStorageType, "name", storage_type_m_name, 1); + rb_define_module_function(mStorageType, "available?", storage_type_m_available_p, 1); + + + +#define INIT_STYPE(id, name) do { \ + storage_type_name_ids[id] = rb_intern(name); \ + } while (0) + + INIT_STYPE(kDefaultStorage, "default"); + INIT_STYPE(kRowSparseStorage, "row_sparse"); + INIT_STYPE(kCSRStorage, "csr"); + +#undef INIT_STYPE + } diff --git a/lib/mxnet/ndarray.rb b/lib/mxnet/ndarray.rb index 83fa77a..142b67e 100644 --- a/lib/mxnet/ndarray.rb +++ b/lib/mxnet/ndarray.rb @@ -2,19 +2,19 @@ module MXNet class NDArray include Enumerable - def self.ones(shape, ctx=nil, dtype=:float32, **kwargs) + def self.ones(shape, ctx: nil, dtype: :float32, **kwargs) ctx ||= Context.default dtype = Utils.dtype_id(dtype) Internal._ones(shape: shape, ctx: ctx, dtype: dtype, **kwargs) end - def self.zeros(shape, ctx=nil, dtype=:float32, **kwargs) + def self.zeros(shape, ctx: nil, dtype: :float32, **kwargs) ctx ||= Context.default dtype = Utils.dtype_id(dtype) Internal._zeros(shape: shape, ctx: ctx, dtype: dtype, **kwargs) end - def self.arange(start, stop=nil, step: 1.0, repeat: 1, ctx: nil, dtype: :float32) + def self.arange(start, stop: nil, step: 1.0, repeat: 1, ctx: nil, dtype: :float32) ctx ||= Context.default dtype = Utils.dtype_name(dtype) Internal._arange(start: start, stop: stop, step: step, repeat: repeat, dtype: dtype, ctx: ctx) @@ -761,6 +761,12 @@ def tile(*args, **kwargs) Ops.tile(self, *args, **kwargs) end + def to_stype stype=:default + raise "To convert to a CSR, the NDArray should be 2 Dimensional. Current " + + "shape is #{shape}" if shape.length != 2 and stype == :csr + Ops.cast_storage(self, stype: stype) + end + GRAD_REQ_MAP = { null: 0, write: 1, @@ -874,4 +880,5 @@ def self.NDArray(array_like, ctx: nil, dtype: :float32) end raise TypeError, "Unable convert #{array_like.class} to MXNet::NDArray" end + end diff --git a/mxnet.gemspec b/mxnet.gemspec index 5ba2fd6..dc6a18f 100644 --- a/mxnet.gemspec +++ b/mxnet.gemspec @@ -25,7 +25,7 @@ Gem::Specification.new do |spec| spec.add_dependency "fiddle" spec.add_dependency "numo-narray" - spec.add_development_dependency "bundler", ">= 1.17" + spec.add_development_dependency "bundler", ">= 2.1.2" spec.add_development_dependency "rake", ">= 12.0" spec.add_development_dependency "rake-compiler" spec.add_development_dependency "rspec", ">= 3.8" diff --git a/spec/mxnet/ndarray_spec.rb b/spec/mxnet/ndarray_spec.rb index 2020c0c..0d2fd7d 100644 --- a/spec/mxnet/ndarray_spec.rb +++ b/spec/mxnet/ndarray_spec.rb @@ -173,6 +173,53 @@ module MXNet end end + + describe '#stype and #to_stype' do + specify do + original = MXNet::NDArray.ones([1, 2]) + expect(original.stype).to eq(:default) + + casted = original.to_stype :csr + expect(casted.stype).to eq(:csr) + #ensure integrity of values + expect(MXNet::NDArray.sum((original == casted), exclude: true).to_i).to eq(original.shape.reduce :*) + + casted = casted.to_stype :default + expect(casted.stype).to eq(:default) + #ensure integrity of values + expect(MXNet::NDArray.sum((original == casted), exclude: true).to_i).to eq(original.shape.reduce :*) + + + casted = original.to_stype :row_sparse + expect(casted.stype).to eq(:row_sparse) + #ensure integrity of values + expect(MXNet::NDArray.sum((original == casted), exclude: true).to_i).to eq(original.shape.reduce :*) + + + casted = casted.to_stype :default + expect(casted.stype).to eq(:default) + #ensure integrity of values + expect(MXNet::NDArray.sum((original == casted), exclude: true).to_i).to eq(original.shape.reduce :*) + + csr_shape_lt_2 = MXNet::NDArray.empty([1]) + expect{csr_shape_lt_2.to_stype :csr }.to raise_error + + csr_shape_gt_2 = MXNet::NDArray.empty([1,2,3]) + expect{csr_shape_gt_2.to_stype :csr }.to raise_error + + + + row_sparse_shape_lt_2 = MXNet::NDArray.empty([1,2,3]) + expect{row_sparse_shape_lt_2.to_stype :row_sparse }.to_not raise_error + + row_sparse_shape_gt_2 = MXNet::NDArray.empty([1,2,3]) + expect{row_sparse_shape_gt_2.to_stype :row_sparse }.to_not raise_error + + + + end + end + describe '#ndim' do specify do x = MXNet::NDArray.empty([3, 2, 1, 4])