From 7a002892c2c608e860d8612aab107feb0c8813dc Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Thu, 5 Nov 2020 19:20:11 +0200 Subject: [PATCH 01/18] Issue #45: GC mark param keys and values --- ext/mxnet/libmxnet.c | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/ext/mxnet/libmxnet.c b/ext/mxnet/libmxnet.c index d0e9101..bf60f81 100644 --- a/ext/mxnet/libmxnet.c +++ b/ext/mxnet/libmxnet.c @@ -132,6 +132,8 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals, ndargs = rb_convert_type(ndargs, T_ARRAY, "Array", "to_ary"); keys = rb_convert_type(keys, T_ARRAY, "Array", "to_ary"); vals = rb_convert_type(vals, T_ARRAY, "Array", "to_ary"); + + if (!NIL_P(out) && !RTEST(rb_obj_is_kind_of(out, mxnet_cNDArray))) { out = rb_convert_type(out, T_ARRAY, "Array", "to_ary"); } @@ -145,8 +147,10 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals, num_params = (int)RARRAY_LEN(keys); keys_str = rb_str_tmp_new(sizeof(char const *)*num_params); + rb_gc_mark(keys_str); params_keys = (char const **)RSTRING_PTR(keys_str); vals_str = rb_str_tmp_new(sizeof(char const *)*num_params); + rb_gc_mark(vals_str); params_vals = (char const **)RSTRING_PTR(vals_str); for (i = 0; i < num_params; ++i) { VALUE key, val; @@ -155,6 +159,8 @@ imperative_invoke(VALUE mod, VALUE handle, VALUE ndargs, VALUE keys, VALUE vals, if (RB_TYPE_P(key, T_SYMBOL)) { key = rb_sym_to_s(key); } + + params_keys[i] = StringValueCStr(key); val = rb_String(RARRAY_AREF(vals, i)); @@ -237,8 +243,10 @@ symbol_creator(VALUE mod, VALUE handle, VALUE args, VALUE kwargs, VALUE keys, VA num_params = (int)RARRAY_LEN(keys); keys_str = rb_str_tmp_new(sizeof(char const **)*num_params); + rb_gc_mark(keys_str); params_keys = (char const **)RSTRING_PTR(keys_str); vals_str = rb_str_tmp_new(sizeof(char const **)*num_params); + rb_gc_mark(vals_str); params_vals = (char const **)RSTRING_PTR(vals_str); for (i = 0; i < num_params; ++i) { VALUE key, val; From 30fa94d17f95721200aeda30d972916f09d30d1f Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:29:05 +0200 Subject: [PATCH 02/18] Bump bundler and ruby version in Docker --- docker/install_ruby.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/install_ruby.sh b/docker/install_ruby.sh index 8b2fb67..ff09113 100755 --- a/docker/install_ruby.sh +++ b/docker/install_ruby.sh @@ -2,10 +2,10 @@ set -ex -RUBY_VERSION=${RUBY_VERSION-2.5.1} +RUBY_VERSION=${RUBY_VERSION-2.6.6} RUBY_MAJOR=$(echo -n $RUBY_VERSION | sed -E 's/\.[0-9]+(-.*)?$//g') RUBYGEMS_VERSION=${RUBYGEMS_VERSION-2.7.7} -BUNDLER_VERSION=${BUNDLER_VERSION-1.16.4} +BUNDLER_VERSION=${BUNDLER_VERSION-2.1.2} case $RUBY_VERSION in 2.6.0) From 6ccfd06c3bef1bd61258bb190d573ff99c0157ac Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:32:05 +0200 Subject: [PATCH 03/18] Bump mxnet and ruby version in rake --- docker/Rakefile | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docker/Rakefile b/docker/Rakefile index 15d64df..65ab752 100644 --- a/docker/Rakefile +++ b/docker/Rakefile @@ -40,7 +40,7 @@ namespace :docker do def run_all(task) %w[ - 2.6.0 + 2.6.6 2.5.3 2.4.5 2.3.8 @@ -52,9 +52,9 @@ namespace :docker do end task :build do - ruby_version = ENV['ruby_version'] || '2.6.0' + ruby_version = ENV['ruby_version'] || '2.6.6' python_version = ENV['python_version'] || '3.7.2' - mxnet_version = ENV['mxnet_version'] || '1.3.1' + mxnet_version = ENV['mxnet_version'] || '1.7.0' docker_build(ruby_version, python_version, mxnet_version) end @@ -65,9 +65,9 @@ namespace :docker do end task :push do - ruby_version = ENV['ruby_version'] || '2.6.0' + ruby_version = ENV['ruby_version'] || '2.6.6' python_version = ENV['python_version'] || '3.7.2' - mxnet_version = ENV['mxnet_version'] || '1.3.1' + mxnet_version = ENV['mxnet_version'] || '1.7.0' docker_push(ruby_version, python_version, mxnet_version) end From f3724f3292c58bf560b1ef6ad510c051c15e3b97 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:35:34 +0200 Subject: [PATCH 04/18] bump docker to mxnet 1.7.x --- docker/Dockerfile.erb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile.erb b/docker/Dockerfile.erb index 66493f2..f22061d 100644 --- a/docker/Dockerfile.erb +++ b/docker/Dockerfile.erb @@ -1,7 +1,7 @@ FROM rubylang/ruby:<%= ruby_version %>-bionic ARG PYTHON_VERSION=3.7.2 -ARG MXNET_VERSION=1.3.1 +ARG MXNET_VERSION=1.7.0 ENV LANG C.UTF-8 ENV DEBIAN_FRONTEND noninteractive From 8476ecb2f4ddbb4e5c1d294a77e1d8129597aa63 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:39:34 +0200 Subject: [PATCH 05/18] Update travis CI for ruby 2.6.6 and mxnet 1.7.0 --- .travis.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.travis.yml b/.travis.yml index d6cd3f6..922374a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -13,6 +13,11 @@ script: rake -f ci/Rakefile ci:run matrix: include: + - name: "Ruby 2.6.6 with MXNet 1.7.0 on Python 3.7.2" + env: + - ruby_version=2.6.6 + - python_version=3.7.2 + - mxnet_version=1.7.0 - name: "Ruby 2.6.0 with MXNet 1.3.1 on Python 3.7.2" env: - ruby_version=2.6.0 From e7dd3ca2329e5181e273fb9b7668f71412825c42 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:40:29 +0200 Subject: [PATCH 06/18] Update Rakefile --- docker/Rakefile | 1 + 1 file changed, 1 insertion(+) diff --git a/docker/Rakefile b/docker/Rakefile index 65ab752..bc43ce0 100644 --- a/docker/Rakefile +++ b/docker/Rakefile @@ -41,6 +41,7 @@ namespace :docker do def run_all(task) %w[ 2.6.6 + 2.6.0 2.5.3 2.4.5 2.3.8 From 5a87a01b127a84c82b1269c6121eb1caaf0160e2 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:41:37 +0200 Subject: [PATCH 07/18] Update to bundler 2.1.2 in gemspec --- mxnet.gemspec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mxnet.gemspec b/mxnet.gemspec index 5ba2fd6..dc6a18f 100644 --- a/mxnet.gemspec +++ b/mxnet.gemspec @@ -25,7 +25,7 @@ Gem::Specification.new do |spec| spec.add_dependency "fiddle" spec.add_dependency "numo-narray" - spec.add_development_dependency "bundler", ">= 1.17" + spec.add_development_dependency "bundler", ">= 2.1.2" spec.add_development_dependency "rake", ">= 12.0" spec.add_development_dependency "rake-compiler" spec.add_development_dependency "rspec", ">= 3.8" From 3fb4a9fece8ec2fe03afc0d17ee0166fa330aa05 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:43:05 +0200 Subject: [PATCH 08/18] Update ci rake file to mxnet 1.7.0 --- ci/Rakefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/Rakefile b/ci/Rakefile index 7ae4faa..607d02e 100644 --- a/ci/Rakefile +++ b/ci/Rakefile @@ -1,8 +1,8 @@ namespace :ci do def get_image_name - ruby_version = ENV['ruby_version'] || '2.5.1' + ruby_version = ENV['ruby_version'] || '2.6.6' python_version = ENV['python_version'] || '3.7.0' - mxnet_version = ENV['mxnet_version'] || '1.2.1.post1' + mxnet_version = ENV['mxnet_version'] || '1.7.0' return ['mrkn/mxnet-rb-ci', [ruby_version, python_version, mxnet_version].join('-')].join(':') end From 3319f80a190c0e6c79a90981c3eeba17764d7890 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:49:40 +0200 Subject: [PATCH 09/18] Update README.md --- README.md | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 1509613..5b475d1 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,4 @@ -# MXNet - -Welcome to your new gem! In this directory, you'll find the files you need to be able to package up your Ruby library into a gem. Put your Ruby code in the file `lib/mxnet`. To experiment with that code, run `bin/console` for an interactive prompt. - -TODO: Delete this and the text above, and describe your gem +Welcome to the ruby mxnet bindings with access to core mxnet modules including NDArray and Gluon ## Installation @@ -21,8 +17,8 @@ Or install it yourself as: $ gem install mxnet ## Usage - -TODO: Write usage instructions here +Make sure that the mxnet library files are available by including the .so in $LD_LIBRARY_PATH environment variable. +`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:` ## Development From bea4501f0e0975c6c4a06f10e8760f470552e430 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 17:50:14 +0200 Subject: [PATCH 10/18] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5b475d1..cd72afe 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -Welcome to the ruby mxnet bindings with access to core mxnet modules including NDArray and Gluon +Welcome to the ruby mxnet bindings with access to core mxnet modules including NDArray and Gluon. The latest version test is 1.7.0 ## Installation From 9555af5a221fb36082a5eba54df59fb146a64aee Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 18:42:28 +0200 Subject: [PATCH 11/18] Add ci status to readme --- README.md | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index cd72afe..93bfbc8 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +[![Build Status](https://travis-ci.org/mrkn/mxnet.rb.svg?branch=master)](https://travis-ci.org/mrkn/mxnet.rb) + Welcome to the ruby mxnet bindings with access to core mxnet modules including NDArray and Gluon. The latest version test is 1.7.0 ## Installation @@ -17,8 +19,11 @@ Or install it yourself as: $ gem install mxnet ## Usage -Make sure that the mxnet library files are available by including the .so in $LD_LIBRARY_PATH environment variable. -`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:` + +To experiment with that code, run bin/console for an interactive prompt. + +Make sure that the mxnet library files are available by including the .so in `LD_LIBRARY_PATH` environment variable. +`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`, alternatively put your Ruby code in the file `lib/mxnet` of this project. ## Development From d54dbcdadb829bfdd383d954fba4ece267c76f8f Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 15 Nov 2020 18:50:58 +0200 Subject: [PATCH 12/18] Include reference to mxnet version docs in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 93bfbc8..a4001a2 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ [![Build Status](https://travis-ci.org/mrkn/mxnet.rb.svg?branch=master)](https://travis-ci.org/mrkn/mxnet.rb) -Welcome to the ruby mxnet bindings with access to core mxnet modules including NDArray and Gluon. The latest version test is 1.7.0 +Welcome to the ruby mxnet bindings with access to core mxnet modules including NDArray and Gluon. The latest version tested is [1.7.0](https://mxnet.apache.org/versions/1.7.0/) ## Installation From 63cd1cc4551e6dc98573d7a8c64d48eefcd76529 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Fri, 27 Nov 2020 18:25:29 +0200 Subject: [PATCH 13/18] consistently use named parameters with NDArray initializer methods --- lib/mxnet/ndarray.rb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/mxnet/ndarray.rb b/lib/mxnet/ndarray.rb index 83fa77a..589c41e 100644 --- a/lib/mxnet/ndarray.rb +++ b/lib/mxnet/ndarray.rb @@ -2,19 +2,19 @@ module MXNet class NDArray include Enumerable - def self.ones(shape, ctx=nil, dtype=:float32, **kwargs) + def self.ones(shape, ctx: nil, dtype: :float32, **kwargs) ctx ||= Context.default dtype = Utils.dtype_id(dtype) Internal._ones(shape: shape, ctx: ctx, dtype: dtype, **kwargs) end - def self.zeros(shape, ctx=nil, dtype=:float32, **kwargs) + def self.zeros(shape, ctx: nil, dtype: :float32, **kwargs) ctx ||= Context.default dtype = Utils.dtype_id(dtype) Internal._zeros(shape: shape, ctx: ctx, dtype: dtype, **kwargs) end - def self.arange(start, stop=nil, step: 1.0, repeat: 1, ctx: nil, dtype: :float32) + def self.arange(start, stop: nil, step: 1.0, repeat: 1, ctx: nil, dtype: :float32) ctx ||= Context.default dtype = Utils.dtype_name(dtype) Internal._arange(start: start, stop: stop, step: step, repeat: repeat, dtype: dtype, ctx: ctx) From 0e1fab708566a5429689e95d34017cd3830c2971 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Fri, 27 Nov 2020 18:29:49 +0200 Subject: [PATCH 14/18] #51: NDArray storage type getter --- ext/mxnet/libmxnet.c | 1 + ext/mxnet/mxnet_internal.h | 13 ++++ ext/mxnet/ndarray.c | 144 ++++++++++++++++++++++++++++++++++++- 3 files changed, 157 insertions(+), 1 deletion(-) diff --git a/ext/mxnet/libmxnet.c b/ext/mxnet/libmxnet.c index bf60f81..c92ac7e 100644 --- a/ext/mxnet/libmxnet.c +++ b/ext/mxnet/libmxnet.c @@ -65,6 +65,7 @@ init_api_table(VALUE handle) INIT_API_TABLE_ENTRY(MXNDArrayGetContext); INIT_API_TABLE_ENTRY(MXNDArrayGetShape); INIT_API_TABLE_ENTRY(MXNDArrayGetDType); + INIT_API_TABLE_ENTRY(MXNDArrayGetStorageType); INIT_API_TABLE_ENTRY(MXNDArraySyncCopyFromCPU); INIT_API_TABLE_ENTRY(MXNDArraySyncCopyToCPU); INIT_API_TABLE_ENTRY(MXNDArrayAt); diff --git a/ext/mxnet/mxnet_internal.h b/ext/mxnet/mxnet_internal.h index 8f726ae..15b1324 100644 --- a/ext/mxnet/mxnet_internal.h +++ b/ext/mxnet/mxnet_internal.h @@ -94,6 +94,17 @@ enum DTypeID { NUMBER_OF_DTYPE_IDS }; +enum StorageTypeID { + // dense + kDefaultStorage = 0, + // row sparse + kRowSparseStorage = 1, + // csr + kCSRStorage = 2, + NUMBER_OF_STORAGE_TYPE_IDS +}; + + struct mxnet_api_table { const char * (* MXGetLastError)(); @@ -141,6 +152,7 @@ struct mxnet_api_table { int (* MXNDArrayGetShape)(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata); int (* MXNDArrayGetDType)(NDArrayHandle handle, int *out_dtype); + int (* MXNDArrayGetStorageType)(NDArrayHandle handle, int *out_storage_type); int (* MXNDArraySyncCopyFromCPU)(NDArrayHandle handle, const void *data, size_t size); int (* MXNDArraySyncCopyToCPU)(NDArrayHandle handle, void *data, size_t size); int (* MXNDArrayAt)(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out); @@ -319,6 +331,7 @@ void *mxnet_get_handle(VALUE obj); void mxnet_set_handle(VALUE obj, VALUE handle_v); VALUE mxnet_dtype_id2name(int dtype_id); +VALUE mxnet_storage_type_id2name(int stype_id); int mxnet_dtype_name2id(VALUE dtype_name); VALUE mxnet_dtype_name(VALUE id_or_name); diff --git a/ext/mxnet/ndarray.c b/ext/mxnet/ndarray.c index 216cff6..b65e786 100644 --- a/ext/mxnet/ndarray.c +++ b/ext/mxnet/ndarray.c @@ -99,6 +99,108 @@ dtype_m_available_p(VALUE mod, VALUE dtype) return mxnet_dtype_is_available(dtype) ? Qtrue : Qfalse; } + +static size_t storage_type_sizes[NUMBER_OF_STORAGE_TYPE_IDS]; +static ID storage_type_name_ids[NUMBER_OF_STORAGE_TYPE_IDS]; + +VALUE +mxnet_storage_type_id2name(int stype_id) +{ + if (0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS) { + return ID2SYM(storage_type_name_ids[stype_id]); + } + + return Qnil; +} + +static VALUE +storage_type_m_id2name(VALUE mod, VALUE stype_id_v) +{ + int stype_id = NUM2INT(stype_id_v); + return mxnet_storage_type_id2name(stype_id); +} + +int +mxnet_storage_type_name2id(VALUE stype_name) +{ + ID stype_name_id; + int i; + + if (!RB_TYPE_P(stype_name, T_SYMBOL)) { + stype_name = rb_to_symbol(StringValue(stype_name)); + } + stype_name_id = SYM2ID(stype_name); + + for (i = 0; i < NUMBER_OF_STORAGE_TYPE_IDS; ++i) { + if (storage_type_name_ids[i] == stype_name_id) { + return i; + } + } + + return -1; +} + +static VALUE +storage_type_m_name2id(VALUE mod, VALUE stype_name) +{ + int stype_id; + stype_id = mxnet_storage_type_name2id(stype_name); + return stype_id == -1 ? Qnil : INT2NUM(stype_id); +} + +VALUE +mxnet_storage_type_name(VALUE id_or_name) +{ + int stype_id; + + if (RB_INTEGER_TYPE_P(id_or_name)) { + stype_id = NUM2INT(id_or_name); + } + else { + stype_id = mxnet_storage_type_name2id(id_or_name); + } + if (0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS) + return mxnet_storage_type_id2name(stype_id); + + return Qnil; +} + +static VALUE +storage_type_m_name(VALUE mod, VALUE id_or_name) +{ + return mxnet_storage_type_name(id_or_name); +} + +int +mxnet_storage_type_is_available(VALUE stype) +{ + int stype_id; + if (RB_INTEGER_TYPE_P(stype)) { + stype_id = NUM2INT(stype); + } + else { + if (!RB_TYPE_P(stype, T_SYMBOL)) { + VALUE str = rb_check_convert_type(stype, T_STRING, "String", "to_str"); + if (NIL_P(str)) { + rb_raise(rb_eTypeError, "Invalid type for storage type (%"PRIsVALUE")", stype); + } + stype = str; + } + stype_id = mxnet_storage_type_name2id(stype); + } + return 0 <= stype_id && stype_id < NUMBER_OF_STORAGE_TYPE_IDS; +} + +static VALUE +storage_type_m_available_p(VALUE mod, VALUE stype) +{ + return mxnet_storage_type_is_available(stype) ? Qtrue : Qfalse; +} + + + + + static void ndarray_free(void *ptr) { @@ -426,6 +528,24 @@ ndarray_get_dtype(VALUE obj) return mxnet_dtype_id2name(dtype_id); } +static int +ndarray_get_storage_type_id(VALUE obj) +{ + NDArrayHandle handle; + int stype_id; + + handle = mxnet_ndarray_get_handle(obj); + CHECK_CALL(MXNET_API(MXNDArrayGetStorageType)(handle, &stype_id)); + return stype_id; +} + +static VALUE +ndarray_get_storage_type(VALUE obj) +{ + int stype_id = ndarray_get_storage_type_id(obj); + return mxnet_storage_type_id2name(stype_id); +} + VALUE mxnet_ndarray_get_shape(VALUE obj) { @@ -741,7 +861,7 @@ ndarray_wait_to_read(VALUE obj) void mxnet_init_ndarray(void) { - VALUE cNDArray, mDType; + VALUE cNDArray, mDType, mStorageType; cNDArray = rb_const_get_at(mxnet_mMXNet, rb_intern("NDArray")); @@ -754,6 +874,7 @@ mxnet_init_ndarray(void) /* TODO: rb_define_singleton_method(cNDArray, "load_from_buffer", ndarray_s_load_from_buffer, 1); */ rb_define_method(cNDArray, "dtype", ndarray_get_dtype, 0); + rb_define_method(cNDArray, "stype", ndarray_get_storage_type, 0); rb_define_method(cNDArray, "shape", mxnet_ndarray_get_shape, 0); rb_define_method(cNDArray, "reshape", ndarray_reshape, 1); rb_define_method(cNDArray, "grad", ndarray_grad, 0); @@ -790,4 +911,25 @@ mxnet_init_ndarray(void) INIT_DTYPE(kInt64, int64_t, "int64"); #undef INIT_DTYPE + + + mStorageType = rb_define_module_under(mxnet_mMXNet, "StorageType"); + + rb_define_module_function(mStorageType, "id2name", storage_type_m_id2name, 1); + rb_define_module_function(mStorageType, "name2id", storage_type_m_name2id, 1); + rb_define_module_function(mStorageType, "name", storage_type_m_name, 1); + rb_define_module_function(mStorageType, "available?", storage_type_m_available_p, 1); + + + +#define INIT_STYPE(id, name) do { \ + storage_type_name_ids[id] = rb_intern(name); \ + } while (0) + + INIT_STYPE(kDefaultStorage, "default"); + INIT_STYPE(kRowSparseStorage, "row_sparse"); + INIT_STYPE(kCSRStorage, "csr"); + +#undef INIT_STYPE + } From 5a1cac7158f02d5e9ec2f92e2fd66e64df62d90a Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sat, 5 Dec 2020 20:07:12 +0200 Subject: [PATCH 15/18] #51: Correct order of functions (mxnet_api_table stuct currently order sensitive) --- ext/mxnet/mxnet_internal.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/mxnet/mxnet_internal.h b/ext/mxnet/mxnet_internal.h index 15b1324..7205a7d 100644 --- a/ext/mxnet/mxnet_internal.h +++ b/ext/mxnet/mxnet_internal.h @@ -152,7 +152,6 @@ struct mxnet_api_table { int (* MXNDArrayGetShape)(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata); int (* MXNDArrayGetDType)(NDArrayHandle handle, int *out_dtype); - int (* MXNDArrayGetStorageType)(NDArrayHandle handle, int *out_storage_type); int (* MXNDArraySyncCopyFromCPU)(NDArrayHandle handle, const void *data, size_t size); int (* MXNDArraySyncCopyToCPU)(NDArrayHandle handle, void *data, size_t size); int (* MXNDArrayAt)(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out); @@ -319,6 +318,7 @@ struct mxnet_api_table { int *num_outputs, NDArrayHandle **outputs, int **out_stypes); + int (* MXNDArrayGetStorageType)(NDArrayHandle handle, int *out_storage_type); }; struct mxnet_api_table *mxnet_get_api_table(void); From c97430584cc523cc6203801a093077ac1f4e6822 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 6 Dec 2020 18:55:37 +0200 Subject: [PATCH 16/18] #51: to_stype implementation --- lib/mxnet/ndarray.rb | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/mxnet/ndarray.rb b/lib/mxnet/ndarray.rb index 589c41e..0b4e469 100644 --- a/lib/mxnet/ndarray.rb +++ b/lib/mxnet/ndarray.rb @@ -761,6 +761,12 @@ def tile(*args, **kwargs) Ops.tile(self, *args, **kwargs) end + def to_stype stype=:default + raise "To convert to a CSR, the NDArray should be 2 Dimensional. Current " + + "shape is #{shape}" if shape.length != 2 + Ops.cast_storage(self, stype: stype) + end + GRAD_REQ_MAP = { null: 0, write: 1, @@ -874,4 +880,5 @@ def self.NDArray(array_like, ctx: nil, dtype: :float32) end raise TypeError, "Unable convert #{array_like.class} to MXNet::NDArray" end + end From 9982025faf40017a849781c6ac28f73fae58b29b Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 6 Dec 2020 18:56:06 +0200 Subject: [PATCH 17/18] #51: stype tests --- spec/mxnet/ndarray_spec.rb | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/spec/mxnet/ndarray_spec.rb b/spec/mxnet/ndarray_spec.rb index 2020c0c..871b3ad 100644 --- a/spec/mxnet/ndarray_spec.rb +++ b/spec/mxnet/ndarray_spec.rb @@ -173,6 +173,30 @@ module MXNet end end + + describe '#stype and #to_stype' do + specify do + a = MXNet::NDArray.empty([1, 2]) + expect(a.stype).to eq(:default) + a = a.to_stype :csr + expect(a.stype).to eq(:csr) + a = a.to_stype :default + expect(a.stype).to eq(:default) + a = a.to_stype :row_sparse + expect(a.stype).to eq(:row_sparse) + a = a.to_stype :default + expect(a.stype).to eq(:default) + + a = Mxnet::NDArray.empty([1]) + expect{a.to_stype :csr }.to raise_error + + a = Mxnet::NDArray.empty([1,2,3]) + expect{a.to_stype :csr }.to raise_error + + + end + end + describe '#ndim' do specify do x = MXNet::NDArray.empty([3, 2, 1, 4]) From decb5f32f57b6eb7ad58c6ed856f7343202bf243 Mon Sep 17 00:00:00 2001 From: Tjad Clark Date: Sun, 6 Dec 2020 19:30:54 +0200 Subject: [PATCH 18/18] #51: Correct logic, add test clarity and completion --- lib/mxnet/ndarray.rb | 2 +- spec/mxnet/ndarray_spec.rb | 53 +++++++++++++++++++++++++++----------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/lib/mxnet/ndarray.rb b/lib/mxnet/ndarray.rb index 0b4e469..142b67e 100644 --- a/lib/mxnet/ndarray.rb +++ b/lib/mxnet/ndarray.rb @@ -763,7 +763,7 @@ def tile(*args, **kwargs) def to_stype stype=:default raise "To convert to a CSR, the NDArray should be 2 Dimensional. Current " + - "shape is #{shape}" if shape.length != 2 + "shape is #{shape}" if shape.length != 2 and stype == :csr Ops.cast_storage(self, stype: stype) end diff --git a/spec/mxnet/ndarray_spec.rb b/spec/mxnet/ndarray_spec.rb index 871b3ad..0d2fd7d 100644 --- a/spec/mxnet/ndarray_spec.rb +++ b/spec/mxnet/ndarray_spec.rb @@ -176,22 +176,45 @@ module MXNet describe '#stype and #to_stype' do specify do - a = MXNet::NDArray.empty([1, 2]) - expect(a.stype).to eq(:default) - a = a.to_stype :csr - expect(a.stype).to eq(:csr) - a = a.to_stype :default - expect(a.stype).to eq(:default) - a = a.to_stype :row_sparse - expect(a.stype).to eq(:row_sparse) - a = a.to_stype :default - expect(a.stype).to eq(:default) - - a = Mxnet::NDArray.empty([1]) - expect{a.to_stype :csr }.to raise_error + original = MXNet::NDArray.ones([1, 2]) + expect(original.stype).to eq(:default) + + casted = original.to_stype :csr + expect(casted.stype).to eq(:csr) + #ensure integrity of values + expect(MXNet::NDArray.sum((original == casted), exclude: true).to_i).to eq(original.shape.reduce :*) + + casted = casted.to_stype :default + expect(casted.stype).to eq(:default) + #ensure integrity of values + expect(MXNet::NDArray.sum((original == casted), exclude: true).to_i).to eq(original.shape.reduce :*) + + + casted = original.to_stype :row_sparse + expect(casted.stype).to eq(:row_sparse) + #ensure integrity of values + expect(MXNet::NDArray.sum((original == casted), exclude: true).to_i).to eq(original.shape.reduce :*) + + + casted = casted.to_stype :default + expect(casted.stype).to eq(:default) + #ensure integrity of values + expect(MXNet::NDArray.sum((original == casted), exclude: true).to_i).to eq(original.shape.reduce :*) - a = Mxnet::NDArray.empty([1,2,3]) - expect{a.to_stype :csr }.to raise_error + csr_shape_lt_2 = MXNet::NDArray.empty([1]) + expect{csr_shape_lt_2.to_stype :csr }.to raise_error + + csr_shape_gt_2 = MXNet::NDArray.empty([1,2,3]) + expect{csr_shape_gt_2.to_stype :csr }.to raise_error + + + + row_sparse_shape_lt_2 = MXNet::NDArray.empty([1,2,3]) + expect{row_sparse_shape_lt_2.to_stype :row_sparse }.to_not raise_error + + row_sparse_shape_gt_2 = MXNet::NDArray.empty([1,2,3]) + expect{row_sparse_shape_gt_2.to_stype :row_sparse }.to_not raise_error + end