-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdata_HBERT.py
More file actions
125 lines (73 loc) · 3.01 KB
/
data_HBERT.py
File metadata and controls
125 lines (73 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 29 16:37:10 2019
(earlier version created on Sun Jun 23 15:21:12 2019)
@author: peterawest
"""
import numpy as np
from nltk.tokenize import word_tokenize
from nltk import ngrams
import random
import torch
def list_find(l, target):
    """Return the index of *target* in *l*, or -1 when absent or target is None.

    Works for both sequences (element index) and strings (substring offset),
    since both support ``.index``.
    """
    if target is None:
        return -1
    try:
        return l.index(target)
    except ValueError:
        return -1
def get_target_post(post, target = None):
    '''Return a '[CLS] '-prefixed, lowercased version of *post*, truncated
    just before *target*, plus whether *target* occurred.

    Parameters
    ----------
    post : str
        Raw post text.
    target : str or None
        Treatment feature to look for. The search runs over the
        space-joined word_tokenize output, so the match respects
        tokenizer boundaries rather than raw text. If None, the whole
        post is returned and the flag is False.

    Returns
    -------
    (str, bool)
        '[CLS] ' + text (the full lowercased post when the target is
        absent, or the tokenized text strictly before the target when
        present), and whether the target was found.
    '''
    word_tokenized = word_tokenize(post.lower())
    tokenized_string = ' '.join(word_tokenized)
    target_ind = list_find(tokenized_string, target)

    if target_ind == -1:
        has_target = False
        post_return = '[CLS] ' + post.lower()
    else:
        # Target present: keep only the text strictly before it.
        # (tokenized_string is built from post.lower(), so it is already
        # lowercase — the original redundant .lower() and an unused
        # re-tokenization of the prefix were removed; behavior unchanged.)
        has_target = True
        post_return = '[CLS] ' + tokenized_string[:target_ind]
    return post_return, has_target
def get_post_list(post_list, tokenizer, target = None, pretokenize = False):
    '''Collect (optionally tokenized) post bodies up to the treatment target.

    target is the target treatment feature. If not None, only the text
    before the target is kept, and processing stops at the first post
    containing it (posts are assumed chronological, so everything kept is
    pre-treatment — all confounds precede the treatment).

    Parameters
    ----------
    post_list : iterable of str
        Posts in chronological order.
    tokenizer : object with .tokenize(str)
        Used only when pretokenize is True.
    target : str or None
        Treatment feature passed through to get_target_post.
    pretokenize : bool
        Tokenize each post body up front instead of storing raw strings.

    Returns
    -------
    (list, bool)
        The collected posts and whether the target was ever found.
    '''
    posts = []
    # BUGFIX: initialize has_target so an empty post_list returns
    # ([], False) instead of raising UnboundLocalError at the return.
    has_target = False
    for post in post_list:
        post_body, has_target = get_target_post(post, target=target)
        if pretokenize:
            # NOTE(review): no_grad looks unnecessary for tokenization
            # (no tensors should be built here) — kept to preserve behavior.
            with torch.no_grad():
                posts.append(tokenizer.tokenize(post_body))
        else:
            posts.append(post_body)
        # If the target is in this post, stop before including anything later.
        if has_target:
            break
    return posts, has_target
def set_word(gram, feature_vector, gram2ind):
    '''Mark *gram* as present in *feature_vector* (mutated in place).

    Looks the gram up in the vocabulary mapping *gram2ind*; unknown grams
    leave the vector untouched. The (possibly updated) vector is returned.
    '''
    try:
        feature_vector[gram2ind[gram]] = 1
    except KeyError:
        pass  # out-of-vocabulary gram: nothing to set
    return feature_vector
def get_features_HBERT(Users, tokenizer, pretokenize = False):
    '''Build the per-user feature lists for HBERT.

    Parameters
    ----------
    Users : iterable
        Each element is handed to get_post_list as a post list — presumably
        a chronologically ordered sequence of post strings per user
        (TODO confirm against callers; commented-out code in the original
        suggested users were once dicts with a 'T0' key).
    tokenizer : object with .tokenize(str)
        Used only when pretokenize is True.
    pretokenize : bool
        Tokenize post bodies up front.

    Returns
    -------
    list
        One entry per user: that user's list of (optionally tokenized) posts.
    '''
    # target is None here, so has_target is always False and all posts are kept.
    # The original built features_by_user and then copied it element-by-element
    # into X — a redundant O(n) pass; X is now built directly (same result).
    X = []
    for user in Users:
        posts, _ = get_post_list(user, tokenizer, pretokenize=pretokenize)
        X.append(posts)
    return X