For training on a custom dataset, say STL-10, should the input be flattened (3x96x96 = 27648 values per image), the same way that here for MNIST we're passing a (784,) vector (28x28)?
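For reference, here is a minimal sketch of what that flattening would look like with NumPy. It assumes the images are already loaded as a NumPy array in channel-first layout `(N, 3, 96, 96)`; the batch below is random synthetic data, and `flatten_images` is a hypothetical helper, not part of this repo. Flattening like this only makes sense if the model's first layer is fully connected and expects a vector (as a `(784,)` MNIST input suggests). A convolutional model would normally keep the `(3, 96, 96)` shape instead.

```python
import numpy as np


def flatten_images(images: np.ndarray) -> np.ndarray:
    """Hypothetical helper: flatten a batch of shape (N, ...) into (N, D) vectors."""
    n = images.shape[0]
    # -1 lets NumPy infer D as the product of all remaining dimensions
    return images.reshape(n, -1)


# MNIST-style batch: 28x28 grayscale images -> 784 features each
mnist_batch = np.zeros((32, 28, 28), dtype=np.float32)

# STL-10-style batch, ASSUMED channel-first (N, 3, 96, 96); synthetic data only
stl_batch = np.random.rand(32, 3, 96, 96).astype(np.float32)

mnist_flat = flatten_images(mnist_batch)
stl_flat = flatten_images(stl_batch)

print(mnist_flat.shape)  # (32, 784)
print(stl_flat.shape)    # (32, 27648)
```

With this approach the input layer size would change from 784 to 27648. If the data is stored channel-last `(N, 96, 96, 3)` instead, the flattened length is the same, but the order of the values differs, so it should be kept consistent between training and inference.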