
Commit 3e41893

[Test] Use WARNING_QUIT instead of exit, Add and refactor unittests for ParaGlobal (#6781)
* Add and refactor unittests for ParaGlobal
* Move a little comma
* Use standard setting style for gtest_death_test_style
* Add nproc and my_rank for ParaGlobalDeathTest SetUp
* Put InitPools in ParaGlobal
* Put InitPools in ParaGlobalDeathTest, switch off CaptureStdout
* Put InitPools in ParaGlobal
* Fix DeathTest so that it only runs on rank 0
* Put InitPools in ParaGlobalDeathTest
* Remove cerr, use warning_quit only
* Use cout for proc/kpar prompt
* Fix Test to run for cout
* Add descriptions of test block
1 parent ce25b0d commit 3e41893


2 files changed: +184 -20 lines changed


source/source_base/parallel_global.cpp

Lines changed: 4 additions & 3 deletions
@@ -328,9 +328,10 @@ void Parallel_Global::divide_pools(const int& NPROC,
     // and MY_BNDGROUP will be the same as well.
     if(BNDPAR > 1 && NPROC %(BNDPAR * KPAR) != 0)
     {
-        std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC << ") must be divisible by the number of groups ("
-                  << BNDPAR * KPAR << ")." << std::endl;
-        exit(1);
+        std::cout << "Error: When BNDPAR = " << BNDPAR << " > 1, number of processes (" << NPROC
+                  << ") must be divisible by the number of groups (" << BNDPAR * KPAR << ")." << std::endl;
+        ModuleBase::WARNING_QUIT("ParallelGlobal::divide_pools",
+                                 "When BNDPAR > 1, number of processes NPROC must be divisible by the number of groups BNDPAR * KPAR.");
     }
     // k-point parallelization
     MPICommGroup kpar_group(MPI_COMM_WORLD);
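The change above swaps a bare exit(1) for ModuleBase::WARNING_QUIT, which records the message in the warning log and terminates with a well-defined status, so the failure branch can now be exercised from gtest. Below is a minimal sketch (not part of this commit) of that pattern; it assumes WARNING_QUIT is declared in source_base/tool_quit.h, prints an "Error" prompt, and exits with code 1, which is what the death tests in the next file rely on.

// Sketch only: covering a WARNING_QUIT branch with a gtest death test.
// Assumption: ModuleBase::WARNING_QUIT prints an "Error" prompt and exits with code 1.
#include <gtest/gtest.h>
#include <unistd.h>

#include "source_base/tool_quit.h"

// Hypothetical helper mirroring the divisibility check guarded in divide_pools().
static void check_divisible(const int nproc, const int ngroup)
{
    if (nproc % ngroup != 0)
    {
        ModuleBase::WARNING_QUIT("check_divisible", "nproc must be divisible by ngroup");
    }
}

TEST(WarningQuitSketchDeathTest, QuitsWhenNotDivisible)
{
    // EXPECT_EXIT runs the statement in a child process and matches its exit status
    // and its stderr; stdout is re-pointed at stderr so the prompt is visible there.
    EXPECT_EXIT(
        {
            dup2(STDERR_FILENO, STDOUT_FILENO);
            check_divisible(12, 5); // 12 is not divisible by 5, so WARNING_QUIT fires
        },
        ::testing::ExitedWithCode(1),
        "Error");
}

As in the ParaGlobalDeathTest fixture below, GlobalV::ofs_warning may need to be opened to a scratch file beforehand so the warning-log write has somewhere to go.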

source/source_base/test_parallel/parallel_global_test.cpp

Lines changed: 180 additions & 17 deletions
@@ -8,8 +8,9 @@
 #include <complex>
 #include <cstring>
 #include <string>
+#include <unistd.h>
 
-#include "source_base/tool_quit.h"
+#include "source_base/global_variable.h"
 
 /************************************************
  * unit test of functions in parallel_global.cpp
@@ -66,6 +67,7 @@ class MPIContext
     int _size;
 };
 
+// --- Normal Test ---
 class ParaGlobal : public ::testing::Test
 {
   protected:
@@ -79,6 +81,7 @@ class ParaGlobal : public ::testing::Test
     }
 };
 
+
 TEST_F(ParaGlobal, SplitGrid)
 {
     // NPROC is set to 4 in parallel_global_test.sh
@@ -162,14 +165,126 @@ TEST_F(ParaGlobal, MyProd)
     EXPECT_EQ(inout[1], std::complex<double>(-3.0, -3.0));
 }
 
-TEST_F(ParaGlobal, InitPools)
+
+
+TEST_F(ParaGlobal, DivideMPIPools)
+{
+    this->nproc = 12;
+    mpi.kpar = 3;
+    this->my_rank = 5;
+    Parallel_Global::divide_mpi_groups(this->nproc,
+                                       mpi.kpar,
+                                       this->my_rank,
+                                       mpi.nproc_in_pool,
+                                       mpi.my_pool,
+                                       mpi.rank_in_pool);
+    EXPECT_EQ(mpi.nproc_in_pool, 4);
+    EXPECT_EQ(mpi.my_pool, 1);
+    EXPECT_EQ(mpi.rank_in_pool, 1);
+}
+
+
+class FakeMPIContext
+{
+  public:
+    FakeMPIContext()
+    {
+        _rank = 0;
+        _size = 1;
+    }
+
+    int GetRank() const
+    {
+        return _rank;
+    }
+    int GetSize() const
+    {
+        return _size;
+    }
+
+    int drank;
+    int dsize;
+    int dcolor;
+
+    int grank;
+    int gsize;
+
+    int kpar;
+    int nproc_in_pool;
+    int my_pool;
+    int rank_in_pool;
+
+    int nstogroup;
+    int MY_BNDGROUP;
+    int rank_in_stogroup;
+    int nproc_in_stogroup;
+
+  private:
+    int _rank;
+    int _size;
+};
+
+// --- DeathTest: Single thread ---
+// Since these precondition checks cause the processes to die, we call such tests death tests.
+// convention of naming the test suite: *DeathTest
+// Death tests should be run in a single-threaded context.
+// Such DeathTest will be run before all other tests.
+class ParaGlobalDeathTest : public ::testing::Test
+{
+  protected:
+    FakeMPIContext mpi;
+    int nproc;
+    int my_rank;
+    int real_rank;
+
+    // DeathTest SetUp:
+    // Init variable, single thread
+    void SetUp() override
+    {
+        int is_init = 0;
+        MPI_Initialized(&is_init);
+        if (is_init) {
+            MPI_Comm_rank(MPI_COMM_WORLD, &real_rank);
+        } else {
+            real_rank = 0;
+        }
+
+        if (real_rank != 0) return;
+
+        nproc = mpi.GetSize();
+        my_rank = mpi.GetRank();
+
+        // init log file needed by WARNING_QUIT
+        GlobalV::ofs_warning.open("warning.log");
+
+
+    }
+
+    // clean log file
+    void TearDown() override
+    {
+        if (real_rank != 0) return;
+
+        GlobalV::ofs_warning.close();
+        remove("warning.log");
+    }
+};
+
+TEST_F(ParaGlobalDeathTest, InitPools)
 {
+    if (real_rank != 0) return;
     nproc = 12;
     mpi.kpar = 3;
     mpi.nstogroup = 3;
     my_rank = 5;
-    testing::internal::CaptureStdout();
-    EXPECT_EXIT(Parallel_Global::init_pools(nproc,
+    EXPECT_EXIT(
+        // This gtest Macro expect that a given `statement` causes the program to exit, with an
+        // integer exit status that satisfies `predicate`(Here ::testing::ExitedWithCode(1)),
+        // and emitting error output that matches `matcher`(Here "Error").
+        {
+            // redirect stdout to stderr to capture WARNING_QUIT output
+            dup2(STDERR_FILENO, STDOUT_FILENO);
+            Parallel_Global::init_pools(nproc,
                                         my_rank,
                                         mpi.nstogroup,
                                         mpi.kpar,
@@ -178,35 +293,83 @@ TEST_F(ParaGlobal, InitPools)
                                         mpi.MY_BNDGROUP,
                                         mpi.nproc_in_pool,
                                         mpi.rank_in_pool,
-                                        mpi.my_pool), ::testing::ExitedWithCode(1), "");
-    std::string output = testing::internal::GetCapturedStdout();
-    EXPECT_THAT(output, testing::HasSubstr("Error:"));
+                                        mpi.my_pool);
+        },
+        ::testing::ExitedWithCode(1),
+        "Error");
 }
 
-
-TEST_F(ParaGlobal, DivideMPIPools)
+TEST_F(ParaGlobalDeathTest, DivideMPIPoolsNgEqZero)
 {
+    if (real_rank != 0) return;
+    // test for num_groups == 0,
+    // Num_group Equals 0
+    // WARNING_QUIT
     this->nproc = 12;
-    mpi.kpar = 3;
-    this->my_rank = 5;
-    Parallel_Global::divide_mpi_groups(this->nproc,
+    mpi.kpar = 0;
+    EXPECT_EXIT(
+        {
+            // redirect stdout to stderr to capture WARNING_QUIT output
+            dup2(STDERR_FILENO, STDOUT_FILENO);
+            Parallel_Global::divide_mpi_groups(this->nproc,
                                        mpi.kpar,
                                        this->my_rank,
                                        mpi.nproc_in_pool,
                                        mpi.my_pool,
                                        mpi.rank_in_pool);
-    EXPECT_EQ(mpi.nproc_in_pool, 4);
-    EXPECT_EQ(mpi.my_pool, 1);
-    EXPECT_EQ(mpi.rank_in_pool, 1);
+        },
+        ::testing::ExitedWithCode(1),
+        "Number of groups must be greater than 0."
+    );
+}
+
+TEST_F(ParaGlobalDeathTest, DivideMPIPoolsNgGtProc)
+{
+    if (real_rank != 0) return;
+    // test for procs < num_groups
+    // Num_group GreaterThan Processors
+    // WARNING_QUIT
+    this->nproc = 12;
+    mpi.kpar = 24;
+    this->my_rank = 5;
+    EXPECT_EXIT(
+        {
+            // redirect stdout to stderr to capture WARNING_QUIT output
+            dup2(STDERR_FILENO, STDOUT_FILENO);
+            Parallel_Global::divide_mpi_groups(this->nproc,
+                                               mpi.kpar,
+                                               this->my_rank,
+                                               mpi.nproc_in_pool,
+                                               mpi.my_pool,
+                                               mpi.rank_in_pool);
+        },
+        testing::ExitedWithCode(1),
+        "Error: Number of processes.*must be greater than the number of groups"
    );
 }
 
 int main(int argc, char** argv)
 {
+    bool is_death_test_child = false;
+    for (int i = 0; i < argc; ++i) {
+        if (std::string(argv[i]).find("gtest_internal_run_death_test") != std::string::npos) {
+            is_death_test_child = true;
+            break;
+        }
+    }
+
+    if (!is_death_test_child)
+    {
+        MPI_Init(&argc, &argv);
+    }
 
-    MPI_Init(&argc, &argv);
     testing::InitGoogleTest(&argc, argv);
+    testing::FLAGS_gtest_death_test_style = "threadsafe";
     int result = RUN_ALL_TESTS();
-    MPI_Finalize();
+
+    if (!is_death_test_child) {
+        MPI_Finalize();
+    }
     return result;
 }
 #endif // __MPI
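The assertions in the new DivideMPIPools test above follow from plain block partitioning: 12 processes split over kpar = 3 pools leaves 4 processes per pool, and global rank 5 then lands in pool 1 with local rank 1. A standalone sketch of that arithmetic (even-division case only; divide_mpi_groups itself also guards the error cases covered by the death tests):

// Standalone sketch of the pool arithmetic behind the DivideMPIPools expectations;
// the variable names mirror the test, the computation itself is illustrative.
#include <cassert>

int main()
{
    const int nproc = 12, kpar = 3, my_rank = 5;

    const int nproc_in_pool = nproc / kpar;           // 12 / 3 = 4 processes per pool
    const int my_pool = my_rank / nproc_in_pool;      // 5 / 4  = 1 -> second pool
    const int rank_in_pool = my_rank % nproc_in_pool; // 5 % 4  = 1 -> local rank 1

    assert(nproc_in_pool == 4);
    assert(my_pool == 1);
    assert(rank_in_pool == 1);
    return 0;
}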

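Two details of the refactored harness are easy to miss. First, EXPECT_EXIT matches only the death-test child's stderr, while the WARNING_QUIT prompt is written to stdout, so every death-test block re-points stdout at stderr with dup2 before calling into Parallel_Global. Second, main() now skips MPI_Init/MPI_Finalize whenever gtest re-executes the binary as a death-test child (detected by the internal gtest_internal_run_death_test argument), so that the re-executed child does not try to set up MPI again under the launcher. A sketch of the redirection in isolation:

// Isolated sketch of the dup2 trick used inside each EXPECT_EXIT block: after the
// call, anything written to stdout goes to wherever stderr points, which is the
// only stream the death-test matcher inspects.
#include <cstdio>
#include <unistd.h>

int main()
{
    dup2(STDERR_FILENO, STDOUT_FILENO); // fd 1 now shares fd 2's destination
    std::printf("Error: this message is visible to the EXPECT_EXIT matcher\n");
    std::fflush(stdout);                // make sure it is written out before exiting
    return 1;
}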