diff --git a/Classification/cnns/alexnet_model.py b/Classification/cnns/alexnet_model.py index c65ab7c..c04bb6d 100644 --- a/Classification/cnns/alexnet_model.py +++ b/Classification/cnns/alexnet_model.py @@ -86,11 +86,11 @@ def alexnet(images, args, need_transpose=False, training=True): data_format=data_format ) - pool1 = flow.nn.avg_pool2d(conv1, 3, 2, "VALID", data_format, name="pool1") + pool1 = flow.nn.max_pool2d(conv1, 3, 2, "VALID", data_format, name="pool1") conv2 = conv2d_layer("conv2", pool1, filters=192, kernel_size=5, data_format=data_format) - pool2 = flow.nn.avg_pool2d(conv2, 3, 2, "VALID", data_format, name="pool2") + pool2 = flow.nn.max_pool2d(conv2, 3, 2, "VALID", data_format, name="pool2") conv3 = conv2d_layer("conv3", pool2, filters=384, data_format=data_format) @@ -98,7 +98,7 @@ def alexnet(images, args, need_transpose=False, training=True): conv5 = conv2d_layer("conv5", conv4, filters=256, data_format=data_format) - pool5 = flow.nn.avg_pool2d(conv5, 3, 2, "VALID", data_format, name="pool5") + pool5 = flow.nn.max_pool2d(conv5, 3, 2, "VALID", data_format, name="pool5") if len(pool5.shape) > 2: pool5 = flow.reshape(pool5, shape=(pool5.shape[0], -1))