-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTorchMinimumImage
More file actions
40 lines (33 loc) · 1.52 KB
/
TorchMinimumImage
File metadata and controls
40 lines (33 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
# torch.set_default_tensor_type(torch.FloatTensor)
# Definition of lattice_vector: the cell vectors a, b, c are stored as rows:
#   [a]
#   [b]
#   [c]
def minimum_image(rij, lattice_vector, inverse_lattice_vector):
    """Fold displacement vectors to their nearest periodic image (general cell).

    Every displacement is expressed in fractional coordinates using
    ``inverse_lattice_vector``, rounded to the nearest integer image index,
    and the matching lattice translation (rebuilt from ``lattice_vector``)
    is subtracted.

    Shape pairings noted by the original author (not validated here):
        rij (F, N, N, Dim) with matrices (F, Dim, Dim) -> (F, N, N, Dim)
        rij (F, N, Dim)    with matrices (Dim, Dim)    -> (F, N, Dim)
        rij (N, Dim)       with matrices (Dim, Dim)    -> (1, N, Dim)

    The leading 1 in the last case comes from the broadcast axis inserted
    into the matrices at position -3.
    """
    # Transposed matrices with an extra broadcast axis in front of them.
    cell_t = lattice_vector.transpose(-2, -1).unsqueeze(-3)
    inv_cell_t = inverse_lattice_vector.transpose(-2, -1).unsqueeze(-3)

    # Treat each displacement as a column: (..., Dim, N).
    columns = rij.transpose(-2, -1)

    # Integer image index per lattice direction.
    image_index = torch.matmul(inv_cell_t, columns).round()

    # Cartesian translation of that image, back in row layout.
    shift = torch.matmul(cell_t, image_index).transpose(-2, -1)
    return rij - shift
def minimum_image_rect(rij, lbox):
    """Fold displacement vectors to their nearest image in a rectangular box.

    Shapes noted by the original author (not validated here):
        rij:  (F, N, Dim) or (N, Dim)
        lbox: (F, Dim)    or (Dim)

    Each component is reduced by the nearest integer multiple of the
    corresponding box length.
    """
    # Insert an axis so the box lengths broadcast over the N displacements.
    box = lbox.unsqueeze(-2)
    n_images = torch.round(rij / box)
    return rij - box * n_images
def scale_position(rij, inverse_lattice_vector):
    """Convert Cartesian vectors to fractional (scaled) coordinates.

    Shapes noted by the original author (not validated here):
        rij: (N, Dim)
        inverse_lattice_vector: (Dim, Dim)

    Returns:
        Tensor of shape (N, Dim) holding the coordinates of each row of
        ``rij`` multiplied through the inverse cell matrix.
    """
    inv_cell_t = torch.transpose(inverse_lattice_vector, -2, -1)
    # Treat each vector as a column so the matrix acts from the left.
    columns = torch.transpose(rij, -2, -1)
    scaled_columns = inv_cell_t @ columns
    return torch.transpose(scaled_columns, -2, -1)
def wrap(r, lbox):
    """Wrap coordinates into the 0..L box using the modulo operator.

    The result takes the sign of ``lbox``, so for positive box lengths
    negative coordinates are mapped back to the non-negative side.
    """
    wrapped = r % lbox
    return wrapped
def unwrap(r, lattice_vector, inverse_lattice_vector):
    """Unwrap a periodic trajectory into continuous coordinates.

    Frame-to-frame displacements are folded to their nearest periodic image
    with ``minimum_image`` and accumulated on top of the first frame, so
    jumps across the cell boundary are removed.

    Shapes noted by the original author (not validated here):
        r: (F, N, Dim), frames along dim 0
        lattice_vector: (Dim, Dim)
        inverse_lattice_vector: (Dim, Dim)

    Returns:
        Tensor of shape (F, N, Dim). The first frame equals ``r[0]``.
        Inputs with fewer than two frames are returned as a copy.
    """
    # Nothing to unwrap with 0 or 1 frames. Without this guard a single
    # frame produced an empty result: the displacement stack has 0 frames
    # and broadcasting r[:1] (1 frame) against it yields 0 frames.
    if r.shape[0] < 2:
        return r.clone()

    dr = minimum_image(torch.diff(r, dim=0), lattice_vector, inverse_lattice_vector)
    # Prepend a zero displacement so the cumulative sum starts at frame 0.
    # zeros_like (rather than dr[:1] * 0.0) keeps the padding exactly zero
    # even if the first displacement contains NaN or inf.
    dr = torch.cat((torch.zeros_like(dr[:1]), dr), dim=0)
    return r[:1] + torch.cumsum(dr, dim=0)