diff --git a/src/tensorNet.cpp b/src/tensorNet.cpp index 57f2b5c..d5029de 100644 --- a/src/tensorNet.cpp +++ b/src/tensorNet.cpp @@ -56,12 +56,12 @@ ICudaEngine* createTrtFromUFF(char* modelpath) { auto parser = createUffParser(); - parser->registerInput("enc_text", DimsCHW(1, VOC_LEN, 1)); - parser->registerInput("dec_text", DimsCHW(1, VOC_LEN, 1)); - parser->registerInput("h0_in", DimsCHW(1, DIM, 1)); - parser->registerInput("c0_in", DimsCHW(1, DIM, 1)); - parser->registerInput("h1_in", DimsCHW(1, DIM, 1)); - parser->registerInput("c1_in", DimsCHW(1, DIM, 1)); + parser->registerInput("enc_text", DimsCHW(1, VOC_LEN, 1), UffInputOrder::kNCHW); + parser->registerInput("dec_text", DimsCHW(1, VOC_LEN, 1), UffInputOrder::kNCHW); + parser->registerInput("h0_in", DimsCHW(1, DIM, 1), UffInputOrder::kNCHW); + parser->registerInput("c0_in", DimsCHW(1, DIM, 1), UffInputOrder::kNCHW); + parser->registerInput("h1_in", DimsCHW(1, DIM, 1), UffInputOrder::kNCHW); + parser->registerInput("c1_in", DimsCHW(1, DIM, 1), UffInputOrder::kNCHW); parser->registerOutput("h0_out"); parser->registerOutput("c0_out");