How can I train only the refinement module? I'm not familiar with PyTorch yet. Thank you.
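For reference, a common PyTorch pattern for training a single submodule is sketched below. It assumes the model exposes the refinement stage as its own submodule; the class `TwoStageModel` and the attribute names `backbone` and `refinement` are hypothetical placeholders, not names from this project. The idea is to freeze every parameter with `requires_grad = False`, unfreeze only the refinement module, and pass just the trainable parameters to the optimizer.

```python
import torch
import torch.nn as nn


# Hypothetical two-stage model: `backbone` and `refinement` are placeholder
# names; substitute the attribute names used by the real model.
class TwoStageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.refinement = nn.Linear(8, 8)

    def forward(self, x):
        return self.refinement(self.backbone(x))


model = TwoStageModel()

# Freeze every parameter, then unfreeze only the refinement module.
for p in model.parameters():
    p.requires_grad = False
for p in model.refinement.parameters():
    p.requires_grad = True

# Give the optimizer only the parameters that should be updated.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)

# Keep the frozen part in eval mode so layers such as BatchNorm or Dropout
# behave as at inference time and their running statistics are not updated.
model.train()
model.backbone.eval()

# One illustrative training step on random data.
x = torch.randn(4, 8)
target = torch.randn(4, 8)
loss = nn.functional.mse_loss(model(x), target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

In practice the pretrained weights for the other stages would be loaded (e.g. with `model.load_state_dict(...)`) before freezing, so that only the refinement module changes during training.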