Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions test/TensorTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <gtest/gtest.h>
#include <torch/all.h>

#include <sstream>
#include <vector>

namespace at {
Expand Down Expand Up @@ -170,5 +171,113 @@ TEST_F(TensorTest, Transpose) {
EXPECT_EQ(transposed.sizes()[2], 2);
}

// 测试 layout
TEST_F(TensorTest, Layout) {
// 默认创建的张量应该是 strided 布局
c10::Layout layout = tensor.layout();
EXPECT_EQ(layout, c10::Layout::Strided);
}

// 测试 layout 常量别名
TEST_F(TensorTest, LayoutConstants) {
// 测试 c10 命名空间下的常量别名
EXPECT_EQ(c10::kStrided, c10::Layout::Strided);
EXPECT_EQ(c10::kSparse, c10::Layout::Sparse);
EXPECT_EQ(c10::kSparseCsr, c10::Layout::SparseCsr);
EXPECT_EQ(c10::kSparseCsc, c10::Layout::SparseCsc);
EXPECT_EQ(c10::kSparseBsr, c10::Layout::SparseBsr);
EXPECT_EQ(c10::kSparseBsc, c10::Layout::SparseBsc);
EXPECT_EQ(c10::kMkldnn, c10::Layout::Mkldnn);
EXPECT_EQ(c10::kJagged, c10::Layout::Jagged);
}

// 测试 at 命名空间下的 layout 常量
TEST_F(TensorTest, LayoutConstantsInAtNamespace) {
EXPECT_EQ(at::kStrided, c10::Layout::Strided);
EXPECT_EQ(at::kSparse, c10::Layout::Sparse);
EXPECT_EQ(at::kSparseCsr, c10::Layout::SparseCsr);
EXPECT_EQ(at::kSparseCsc, c10::Layout::SparseCsc);
EXPECT_EQ(at::kSparseBsr, c10::Layout::SparseBsr);
EXPECT_EQ(at::kSparseBsc, c10::Layout::SparseBsc);
EXPECT_EQ(at::kMkldnn, c10::Layout::Mkldnn);
EXPECT_EQ(at::kJagged, c10::Layout::Jagged);
}

// 测试 torch 命名空间下的 layout 常量
TEST_F(TensorTest, LayoutConstantsInTorchNamespace) {
EXPECT_EQ(torch::kStrided, c10::Layout::Strided);
EXPECT_EQ(torch::kSparse, c10::Layout::Sparse);
EXPECT_EQ(torch::kSparseCsr, c10::Layout::SparseCsr);
EXPECT_EQ(torch::kSparseCsc, c10::Layout::SparseCsc);
EXPECT_EQ(torch::kSparseBsr, c10::Layout::SparseBsr);
EXPECT_EQ(torch::kSparseBsc, c10::Layout::SparseBsc);
EXPECT_EQ(torch::kMkldnn, c10::Layout::Mkldnn);
EXPECT_EQ(torch::kJagged, c10::Layout::Jagged);
}

// 测试 layout 枚举值
TEST_F(TensorTest, LayoutEnumValues) {
// 测试 Layout 枚举的底层值
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Strided), 0);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Sparse), 1);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseCsr), 2);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Mkldnn), 3);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseCsc), 4);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseBsr), 5);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseBsc), 6);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Jagged), 7);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::NumOptions), 8);
}

Comment on lines +218 to +231
Copy link

Copilot AI Jan 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing the underlying enum values (casting to int8_t) creates a brittle test that depends on the internal representation of the enum. If the enum ordering changes in PyTorch/ATen, this test will fail. Consider removing this test or documenting why the specific enum values must remain stable.

Suggested change
// 测试 layout 枚举值
TEST_F(TensorTest, LayoutEnumValues) {
// 测试 Layout 枚举的底层值
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Strided), 0);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Sparse), 1);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseCsr), 2);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Mkldnn), 3);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseCsc), 4);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseBsr), 5);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::SparseBsc), 6);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::Jagged), 7);
EXPECT_EQ(static_cast<int8_t>(c10::Layout::NumOptions), 8);
}

Copilot uses AI. Check for mistakes.
// 测试 layout 输出流操作符
TEST_F(TensorTest, LayoutOutputStream) {
std::ostringstream oss;

oss.str("");
oss << c10::Layout::Strided;
EXPECT_EQ(oss.str(), "Strided");

oss.str("");
oss << c10::Layout::Sparse;
EXPECT_EQ(oss.str(), "Sparse");

oss.str("");
oss << c10::Layout::SparseCsr;
EXPECT_EQ(oss.str(), "SparseCsr");

oss.str("");
oss << c10::Layout::SparseCsc;
EXPECT_EQ(oss.str(), "SparseCsc");

oss.str("");
oss << c10::Layout::SparseBsr;
EXPECT_EQ(oss.str(), "SparseBsr");

oss.str("");
oss << c10::Layout::SparseBsc;
EXPECT_EQ(oss.str(), "SparseBsc");

oss.str("");
oss << c10::Layout::Mkldnn;
EXPECT_EQ(oss.str(), "Mkldnn");

oss.str("");
oss << c10::Layout::Jagged;
EXPECT_EQ(oss.str(), "Jagged");
Comment on lines +235 to +266
Copy link

Copilot AI Jan 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test LayoutOutputStream contains repetitive code that clears and writes to the same ostringstream. This reduces maintainability. Consider using a helper function or restructuring the test to reduce duplication.

Suggested change
oss.str("");
oss << c10::Layout::Strided;
EXPECT_EQ(oss.str(), "Strided");
oss.str("");
oss << c10::Layout::Sparse;
EXPECT_EQ(oss.str(), "Sparse");
oss.str("");
oss << c10::Layout::SparseCsr;
EXPECT_EQ(oss.str(), "SparseCsr");
oss.str("");
oss << c10::Layout::SparseCsc;
EXPECT_EQ(oss.str(), "SparseCsc");
oss.str("");
oss << c10::Layout::SparseBsr;
EXPECT_EQ(oss.str(), "SparseBsr");
oss.str("");
oss << c10::Layout::SparseBsc;
EXPECT_EQ(oss.str(), "SparseBsc");
oss.str("");
oss << c10::Layout::Mkldnn;
EXPECT_EQ(oss.str(), "Mkldnn");
oss.str("");
oss << c10::Layout::Jagged;
EXPECT_EQ(oss.str(), "Jagged");
auto checkLayoutToString = [&oss](c10::Layout layout, const std::string& expected) {
oss.str("");
oss.clear();
oss << layout;
EXPECT_EQ(oss.str(), expected);
};
checkLayoutToString(c10::Layout::Strided, "Strided");
checkLayoutToString(c10::Layout::Sparse, "Sparse");
checkLayoutToString(c10::Layout::SparseCsr, "SparseCsr");
checkLayoutToString(c10::Layout::SparseCsc, "SparseCsc");
checkLayoutToString(c10::Layout::SparseBsr, "SparseBsr");
checkLayoutToString(c10::Layout::SparseBsc, "SparseBsc");
checkLayoutToString(c10::Layout::Mkldnn, "Mkldnn");
checkLayoutToString(c10::Layout::Jagged, "Jagged");

Copilot uses AI. Check for mistakes.
}

// 测试使用 kStrided 常量与 tensor.layout() 比较
TEST_F(TensorTest, LayoutWithConstant) {
// 使用常量别名进行比较
EXPECT_EQ(tensor.layout(), at::kStrided);
EXPECT_EQ(tensor.layout(), torch::kStrided);
EXPECT_EQ(tensor.layout(), c10::kStrided);

// 确保不是其他布局类型
EXPECT_NE(tensor.layout(), at::kSparse);
EXPECT_NE(tensor.layout(), at::kSparseCsr);
EXPECT_NE(tensor.layout(), at::kMkldnn);
}

} // namespace test
} // namespace at