@@ -79,11 +79,25 @@ def do_POST(self):
7979 self .wfile .write (bytes (f"Unhandled path: { self .path } " , "utf-8" ))
8080
8181class TestKaggleModuleResolver (unittest .TestCase ):
82- def test_kaggle_resolver_succeeds (self ):
82+ def test_kaggle_resolver_long_url_succeeds (self ):
83+ model_url = "https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2"
8384 with create_test_server (KaggleJwtHandler ) as addr :
8485 test_inputs = tf .ones ([1 ,4 ])
85- layer = hub .KerasLayer ("https://kaggle.com/models/foo/foomodule/frameworks/TensorFlow2/variations/barvar/versions/2" )
86+ layer = hub .KerasLayer (model_url )
8687 self .assertEqual ([1 , 1 ], layer (test_inputs ).shape )
88+ # Delete the files that were created in KaggleJwtHandler's do_POST method
89+ os .unlink (os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" ))
90+ os .rmdir (os .path .dirname (os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" )))
91+
92+ def test_kaggle_resolver_short_url_succeeds (self ):
93+ model_url = "https://kaggle.com/models/foo/foomodule/TensorFlow2/barvar/2"
94+ with create_test_server (KaggleJwtHandler ) as addr :
95+ test_inputs = tf .ones ([1 ,4 ])
96+ layer = hub .KerasLayer (model_url )
97+ self .assertEqual ([1 , 1 ], layer (test_inputs ).shape )
98+ # Delete the files that were created in KaggleJwtHandler's do_POST method
99+ os .unlink (os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" ))
100+ os .rmdir (os .path .dirname (os .path .join (MOUNT_PATH , "foomodule/tensorflow2/barvar/2" )))
87101
88102 def test_kaggle_resolver_not_attached_throws (self ):
89103 with create_test_server (KaggleJwtHandler ) as addr :
0 commit comments