PyTorch now has native support for distributed tensors (DTensor), which might be a better way to do tensor parallelism (TP) than Megatron's MPU.