30 changes: 29 additions & 1 deletion ClickThroughRate/WideDeepLearning/README.md
@@ -43,4 +43,32 @@ python3 wdl_train_eval_test.py \
--gpu_num 1
```

## Run OneFlow-WDL with hybrid embedding
Click-through-rate and recommendation workloads often involve extremely large vocabularies. Holding such a vocabulary entirely in GPU memory may require many cards, while a model like Wide&Deep needs relatively little compute, so those cards end up underutilized. A hybrid CPU-GPU scheme strikes a balance: all rows of the vocabulary are sorted by access frequency collected from the training set, the most frequently accessed rows are placed in device memory, and the remaining rows are placed in host memory. This uses the device's compute and memory resources while also exploiting the host's abundant memory, so a small number of GPUs can handle a much larger vocabulary with no loss in performance (a conceptual sketch of the ID remapping follows the run command below).

```
# VOCAB_SIZE=160361600
VOCAB_SIZE=16036160
HF_VOCAB_SIZE=801808
# export ONEFLOW_DEBUG_MODE=1
DATA_ROOT=/dataset/wdl_ofrecord/hf_ofrecord
python3 wdl_train_eval_with_hybrid_embd.py \
--train_data_dir $DATA_ROOT/train \
--train_data_part_num 256 \
--train_part_name_suffix_length=5 \
--eval_data_dir $DATA_ROOT/val \
--eval_data_part_num 256 \
--eval_part_name_suffix_length=5 \
--max_iter=1100 \
--loss_print_every_n_iter=100 \
--eval_interval=100000 \
--batch_size=16384 \
--wide_vocab_size=$VOCAB_SIZE \
--deep_vocab_size=$VOCAB_SIZE \
--hf_wide_vocab_size=$HF_VOCAB_SIZE \
--hf_deep_vocab_size=$HF_VOCAB_SIZE \
--num_dataloader_thread_per_gpu=8 \
--use_single_dataloader_thread \
--gpu_num 4
```
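In the script above, `HF_VOCAB_SIZE` is set to 1/20 of `VOCAB_SIZE`, so only the hottest 5% of vocabulary rows go into the device-resident (high-frequency) tables while the rest stay in host memory. The following minimal NumPy sketch illustrates the frequency-based ID remapping idea; the helper names and the rank-based remapping are illustrative assumptions, not the actual implementation behind `_hybrid_embedding` in `wdl_train_eval_with_hybrid_embd.py`.

```
import numpy as np

def build_hybrid_remap(access_counts):
    """Illustrative: rank vocabulary rows by training-set access frequency.

    After remapping, ids below hf_vocab_size refer to the hottest rows
    (placed in device memory); the rest refer to host-resident rows.
    """
    order = np.argsort(-access_counts)      # most frequently accessed first
    new_id = np.empty_like(order)
    new_id[order] = np.arange(len(order))   # new id = frequency rank
    return new_id

def split_ids(ids, new_id, hf_vocab_size):
    remapped = new_id[ids]
    on_device = remapped < hf_vocab_size
    hf_ids = remapped[on_device]                   # look up in the GPU table
    lf_ids = remapped[~on_device] - hf_vocab_size  # look up in the host table
    return hf_ids, lf_ids, on_device

# Toy example: 10-row vocabulary, keep the 3 hottest rows on the device.
counts = np.array([5, 120, 7, 3, 990, 42, 1, 88, 15, 2])
new_id = build_hybrid_remap(counts)
hf_ids, lf_ids, mask = split_ids(np.array([4, 0, 1, 9]), new_id, hf_vocab_size=3)
```
Conceptually, the results of the two partial lookups are then merged back into batch order before feeding the rest of the network.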
The OneFlow-WDL network implements model parallelism and sparse updates. On a server with 8 12GB TitanV GPUs it supports a vocabulary of more than 400 million rows with no performance loss compared to a small vocabulary. For details, see the evaluation section of [this document](https://github.com/Oneflow-Inc/oneflow-documentation/blob/master/cn/docs/adv_examples/wide_deep.md).
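"Sparse updates" here means that only the embedding rows actually looked up by a batch are updated, rather than the whole table; in the training script below this is what `indexed_slices_optimizer_conf` enables for the embedding ops. A rough NumPy sketch of that idea, with illustrative names only:

```
import numpy as np

def sparse_sgd_update(table, touched_ids, row_grads, lr=0.1):
    """Illustrative indexed-slices update: touch only the looked-up rows."""
    unique_ids, inverse = np.unique(touched_ids, return_inverse=True)
    summed = np.zeros((len(unique_ids), table.shape[1]), dtype=table.dtype)
    np.add.at(summed, inverse, row_grads)   # accumulate gradients of duplicate ids
    table[unique_ids] -= lr * summed        # all other rows are left untouched
    return table

table = np.ones((8, 4), dtype=np.float32)          # tiny embedding table
ids = np.array([2, 5, 2])                          # lookups in one batch (id 2 twice)
grads = np.full((3, 4), 0.5, dtype=np.float32)     # per-lookup gradients
sparse_sgd_update(table, ids, grads)               # only rows 2 and 5 change
```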
@@ -76,17 +76,17 @@ def _data_loader(data_dir, data_part_num, batch_size, part_name_suffix_length=-1
devices = ['{}:0-{}'.format(i, num_dataloader_thread - 1) for i in range(FLAGS.num_nodes)]
with flow.scope.placement("cpu", devices):
if FLAGS.dataset_format == 'ofrecord':
data = _data_loader_ofrecord(data_dir, data_part_num, batch_size,
part_name_suffix_length, shuffle)
elif FLAGS.dataset_format == 'onerec':
data = _data_loader_onerec(data_dir, batch_size, shuffle)
elif FLAGS.dataset_format == 'synthetic':
data = _data_loader_synthetic(batch_size)
else:
assert 0, "Please specify dataset_type as `ofrecord`, `onerec` or `synthetic`."
return flow.identity_n(data)



def _data_loader_ofrecord(data_dir, data_part_num, batch_size, part_name_suffix_length=-1,
shuffle=True):
@@ -109,10 +109,10 @@ def _blob_decoder(bn, shape, dtype=flow.int32):

def _data_loader_synthetic(batch_size):
def _blob_random(shape, dtype=flow.int32, initializer=flow.zeros_initializer(flow.int32)):
return flow.data.decode_random(shape=shape, dtype=dtype, batch_size=batch_size,
initializer=initializer)
labels = _blob_random((1,), initializer=flow.random_uniform_initializer(dtype=flow.int32))
dense_fields = _blob_random((FLAGS.num_dense_fields,), dtype=flow.float,
initializer=flow.random_uniform_initializer())
wide_sparse_fields = _blob_random((FLAGS.num_wide_sparse_fields,))
deep_sparse_fields = _blob_random((FLAGS.num_deep_sparse_fields,))
@@ -189,13 +189,13 @@ def _embedding(name, ids, embedding_size, vocab_size, split_axis=0):

def _model(dense_fields, wide_sparse_fields, deep_sparse_fields):
# wide_embedding = _embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size)
wide_embedding = _hybrid_embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size,
FLAGS.hf_wide_vocab_size)
wide_scores = flow.math.reduce_sum(wide_embedding, axis=[1], keepdims=True)
wide_scores = flow.parallel_cast(wide_scores, distribute=flow.distribute.split(0),
gradient_distribute=flow.distribute.broadcast())

# deep_embedding = _embedding('deep_embedding', deep_sparse_fields, FLAGS.deep_embedding_vec_size,
# FLAGS.deep_vocab_size, split_axis=1)
deep_embedding = _hybrid_embedding('deep_embedding', deep_sparse_fields, FLAGS.deep_embedding_vec_size,
FLAGS.deep_vocab_size, FLAGS.hf_deep_vocab_size)
@@ -250,11 +250,11 @@ def _get_train_conf():
train_conf = flow.FunctionConfig()
train_conf.default_data_type(flow.float)
indexed_slices_ops = [
'wide_embedding',
'deep_embedding',
'hf_wide_embedding',  # hf_* tables hold the high-frequency rows kept in device memory
'hf_deep_embedding',
'lf_wide_embedding',  # lf_* tables hold the remaining low-frequency rows in host memory
'lf_deep_embedding',
]
train_conf.indexed_slices_optimizer_conf(dict(include_op_names=dict(op_name=indexed_slices_ops)))
@@ -264,8 +264,8 @@ def _get_train_conf():
@flow.global_function('train', _get_train_conf())
def train_job():
labels, dense_fields, wide_sparse_fields, deep_sparse_fields = \
_data_loader(data_dir=FLAGS.train_data_dir, data_part_num=FLAGS.train_data_part_num,
batch_size=FLAGS.batch_size,
part_name_suffix_length=FLAGS.train_part_name_suffix_length, shuffle=True)
logits = _model(dense_fields, wide_sparse_fields, deep_sparse_fields)
loss = flow.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
@@ -316,8 +316,6 @@ def main():
flow.config.enable_model_io_v2(True)
flow.config.enable_debug_mode(True)
flow.config.collective_boxing.nccl_enable_all_to_all(True)
check_point = flow.train.CheckPoint()
check_point.init()
for i in range(FLAGS.max_iter):
train_job().async_get(_create_train_callback(i))
if (i + 1 ) % FLAGS.eval_interval == 0: