-
Notifications
You must be signed in to change notification settings - Fork 294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
修改RandU和RandN函数并添加相关单元测试 #43
Conversation
TEST(test_utensor, copy_construct1) {
using namespace kuiper_infer;
Tensor<uint8_t> f1(3, 224, 224);
f1.RandU(2, 4);
f1.Show();
Tensor<uint8_t> f2(f1);
ASSERT_EQ(f2.channels(), 3);
ASSERT_EQ(f2.rows(), 224);
ASSERT_EQ(f2.cols(), 224);
ASSERT_TRUE(arma::approx_equal(f2.data(), f1.data(), "absdiff", 1e-4));
} 第三點"RandN函数和RandU函数中,其代码实现无法支持uint8数据类型的使用", 我不太清楚你說的不支持是什麼意思,我這裏按照這樣的調用是可以生成2-4之間的隨機數分佈的。 |
好像是windows才出現這個問題 |
并不是windows才会出现这样的问题。我对模版类的特化并没有太深入的思考。我认为目前的在.hpp文件中实现函数功能的做法,其函数在特化时是动态进行的,而我提交的PR中,其函数特化是在编译时期完成的。因此出现了RandU函数和RandN的BUG。先从RandN说起,如果在目前的代码中,添加如下测试: TEST(test_utensor, randn) { ASSERT_TRUE(arma::approx_equal(f2.data(), f1.data(), "absdiff", 1e-4)); 在编译时就会出现bug,因为RandN函数中使用了的正态分布函数normal_distribution只支持浮点及其衍生类型,可以参考https://en.cppreference.com/w/cpp/numeric/random/normal_distribution,其源代码中也加入了static_assert(std::is_floating_point<_RealType>作静态验证,因此当Tensor<uint_8>调用该函数时,就会出现错误,由于之前没有RandN的相关测试,因此这个函数并没有被实际编译。 基于上述原因,必须在编译前特化相关类型的实现,从而使编译器检查出错误。而这一点又导致了RandU函数的错误,RandU中有两个静态类型断言,无论是什么类型,上述断言总有一个是错误的,因此无法通过编译,所以我将RandU函数对不同数据类型给出了不同的实现。 如果把静态断言修改为普通的类型判断应该更好的解决这个问题,不过我还没有进行测试。因为这个地方的目的就是判断数据类型T是一个浮点还是整数。而使用断言则会导致程序异常终止。 |
有以下的问题: |
This reverts commit 81c7c93.
我已经撤回了test_*.cpp中的修改,这是编辑器自动进行的格式化保存,以后我会注意。另外关于RandN的特化问题,原函数的实现使用的正态分布函数并不支持int,uint_8类型,因此我还没有太好的想法用于生成整数正态分布随机数,如果想到,我会进行编写及功能测试。 |
你可以看这里,出自github的自动编译,windows不支持std::uniform_int_distribution<uint8_t>。 |
这个很奇怪啊,为什么标准库函数在不同的平台下会有差异呢。会不会是GitHub平台的windows编译器选择的有问题?我查找到一个VS2022的文档,https://learn.microsoft.com/zh-cn/cpp/standard-library/uniform-int-distribution-class?view=msvc-170 里面明确指示了是支持uint类型的。 |
是的,还有你看这个项目的gitub action在windows上的编译已经挂了好久了,但是我本地的windows环境又还原不出来,一直在排查中。 |
现在只能用这个临时方法应付一下 template <>
void Tensor<std::uint8_t>::RandU(std::uint8_t min, std::uint8_t max) {
CHECK(!this->data_.empty());
std::random_device rd;
std::mt19937 mt(rd());
#ifdef _MSC_VER
std::uniform_int_distribution<int> dist(min, max);
uint8_t max_value = std::numeric_limits<uint8_t>::max();
for (uint32_t i = 0; i < this->size(); ++i) {
this->index(i) = dist(mt) % max_value;
}
#else
...
... |
注:本次提交的PR可能包含不同的主题,我并没有太多向公共仓库提交PR的经验,不知道把所有commit合并在一起提交的方式是否合适,如有不妥,欢迎指正。本次的修改主要包括: