Skip to content

Commit b78357b

Browse files
authoredJul 15, 2024··
Clean up handling of temporary files in test suite (lsils#652)
* Don't write `test.aig` from `write_aiger` test I assume this was added for debugging purposes at some point. The test doesn't need it and the current working directory might not be writable. * Set the current working directory to a temporary directory when running tests The incoming current working directory may not be writable. * Propagate I/O errors when (de)serializing This is especially important when deserializing. If we don't catch errors after reading `size`, we may use its uninitialized value as a loop bound, which is very bad. * Add tests for I/O error propagation
1 parent da7c921 commit b78357b

File tree

5 files changed

+229
-34
lines changed

5 files changed

+229
-34
lines changed
 

‎include/mockturtle/io/serialize.hpp

+83-20
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
#include "../networks/aig.hpp"
4444
#include <fstream>
45+
#include <optional>
4546
#include <parallel_hashmap/phmap_dump.h>
4647

4748
namespace mockturtle
@@ -93,7 +94,10 @@ struct serializer
9394
bool operator()( phmap::BinaryOutputArchive& os, regular_node<Fanin, Size, PointerFieldSize> const& n ) const
9495
{
9596
uint64_t size = n.children.size();
96-
os.dump( (char*)&size, sizeof( uint64_t ) );
97+
if ( !os.dump( (char*)&size, sizeof( uint64_t ) ) )
98+
{
99+
return false;
100+
}
97101

98102
for ( const auto& c : n.children )
99103
{
@@ -105,7 +109,10 @@ struct serializer
105109
}
106110

107111
size = n.data.size();
108-
os.dump( (char*)&size, sizeof( uint64_t ) );
112+
if ( !os.dump( (char*)&size, sizeof( uint64_t ) ) )
113+
{
114+
return false;
115+
}
109116
for ( const auto& d : n.data )
110117
{
111118
bool result = this->operator()( os, d );
@@ -122,7 +129,11 @@ struct serializer
122129
bool operator()( phmap::BinaryInputArchive& ar_input, const regular_node<Fanin, Size, PointerFieldSize>* n ) const
123130
{
124131
uint64_t size;
125-
ar_input.load( (char*)&size, sizeof( uint64_t ) );
132+
if ( !ar_input.load( (char*)&size, sizeof( uint64_t ) ) )
133+
{
134+
return false;
135+
}
136+
126137
for ( uint64_t i = 0; i < size; ++i )
127138
{
128139
pointer_type ptr;
@@ -163,7 +174,10 @@ struct serializer
163174
{
164175
/* nodes */
165176
uint64_t size = storage.nodes.size();
166-
os.dump( (char*)&size, sizeof( uint64_t ) );
177+
if ( !os.dump( (char*)&size, sizeof( uint64_t ) ) )
178+
{
179+
return false;
180+
}
167181
for ( const auto& n : storage.nodes )
168182
{
169183
if ( !this->operator()( os, n ) )
@@ -174,7 +188,10 @@ struct serializer
174188

175189
/* inputs */
176190
size = storage.inputs.size();
177-
os.dump( (char*)&size, sizeof( uint64_t ) );
191+
if ( !os.dump( (char*)&size, sizeof( uint64_t ) ) )
192+
{
193+
return false;
194+
}
178195
for ( const auto& i : storage.inputs )
179196
{
180197
if ( !this->operator()( os, i ) )
@@ -185,7 +202,10 @@ struct serializer
185202

186203
/* outputs */
187204
size = storage.outputs.size();
188-
os.dump( (char*)&size, sizeof( uint64_t ) );
205+
if ( !os.dump( (char*)&size, sizeof( uint64_t ) ) )
206+
{
207+
return false;
208+
}
189209
for ( const auto& o : storage.outputs )
190210
{
191211
if ( !this->operator()( os, o ) )
@@ -200,7 +220,10 @@ struct serializer
200220
return false;
201221
}
202222

203-
os.dump( (char*)&storage.trav_id, sizeof( uint32_t ) );
223+
if ( !os.dump( (char*)&storage.trav_id, sizeof( uint32_t ) ) )
224+
{
225+
return false;
226+
}
204227

205228
return true;
206229
}
@@ -209,7 +232,10 @@ struct serializer
209232
{
210233
/* nodes */
211234
uint64_t size;
212-
ar_input.load( (char*)&size, sizeof( uint64_t ) );
235+
if ( !ar_input.load( (char*)&size, sizeof( uint64_t ) ) )
236+
{
237+
return false;
238+
}
213239
for ( uint64_t i = 0; i < size; ++i )
214240
{
215241
node_type n;
@@ -221,16 +247,25 @@ struct serializer
221247
}
222248

223249
/* inputs */
224-
ar_input.load( (char*)&size, sizeof( uint64_t ) );
250+
if ( !ar_input.load( (char*)&size, sizeof( uint64_t ) ) )
251+
{
252+
return false;
253+
}
225254
for ( uint64_t i = 0; i < size; ++i )
226255
{
227256
uint64_t value;
228-
ar_input.load( (char*)&value, sizeof( uint64_t ) );
257+
if ( !ar_input.load( (char*)&value, sizeof( uint64_t ) ) )
258+
{
259+
return false;
260+
}
229261
storage->inputs.push_back( value );
230262
}
231263

232264
/* outputs */
233-
ar_input.load( (char*)&size, sizeof( uint64_t ) );
265+
if ( !ar_input.load( (char*)&size, sizeof( uint64_t ) ) )
266+
{
267+
return false;
268+
}
234269
for ( uint64_t i = 0; i < size; ++i )
235270
{
236271
pointer_type ptr;
@@ -247,23 +282,36 @@ struct serializer
247282
return false;
248283
}
249284

250-
ar_input.load( (char*)&storage->trav_id, sizeof( uint32_t ) );
285+
if ( !ar_input.load( (char*)&storage->trav_id, sizeof( uint32_t ) ) )
286+
{
287+
return false;
288+
}
251289

252290
return true;
253291
}
254292
}; /* struct serializer */
255293

256294
} /* namespace detail */
257295

296+
/*! \brief Serializes a combinational AIG network to a archive, returning false on failure
297+
*
298+
* \param aig Combinational AIG network
299+
* \param os Output archive
300+
*/
301+
inline bool serialize_network_fallible( aig_network const& aig, phmap::BinaryOutputArchive& os )
302+
{
303+
detail::serializer _serializer;
304+
return _serializer( os, *aig._storage );
305+
}
306+
258307
/*! \brief Serializes a combinational AIG network to a archive
259308
*
260309
* \param aig Combinational AIG network
261310
* \param os Output archive
262311
*/
263312
inline void serialize_network( aig_network const& aig, phmap::BinaryOutputArchive& os )
264313
{
265-
detail::serializer _serializer;
266-
bool const okay = _serializer( os, *aig._storage );
314+
bool const okay = serialize_network_fallible( aig, os );
267315
(void)okay;
268316
assert( okay && "failed to serialize the network onto stream" );
269317
}
@@ -279,12 +327,12 @@ inline void serialize_network( aig_network const& aig, std::string const& filena
279327
serialize_network( aig, ar_out );
280328
}
281329

282-
/*! \brief Deserializes a combinational AIG network from a input archive
330+
/*! \brief Deserializes a combinational AIG network from a input archive, returning nullopt on failure
283331
*
284332
* \param ar_input Input archive
285333
* \return Deserialized AIG network
286334
*/
287-
inline aig_network deserialize_network( phmap::BinaryInputArchive& ar_input )
335+
inline std::optional<aig_network> deserialize_network_fallible( phmap::BinaryInputArchive& ar_input )
288336
{
289337
detail::serializer _serializer;
290338
auto storage = std::make_shared<aig_storage>();
@@ -293,10 +341,25 @@ inline aig_network deserialize_network( phmap::BinaryInputArchive& ar_input )
293341
storage->outputs.clear();
294342
storage->hash.clear();
295343

296-
bool const okay = _serializer( ar_input, storage.get() );
297-
(void)okay;
298-
assert( okay && "failed to deserialize the network onto stream" );
299-
return aig_network{ storage };
344+
if ( _serializer( ar_input, storage.get() ) )
345+
{
346+
return aig_network{ storage };
347+
}
348+
349+
return std::nullopt;
350+
}
351+
352+
/*! \brief Deserializes a combinational AIG network from a input archive
353+
*
354+
* \param ar_input Input archive
355+
* \return Deserialized AIG network
356+
*/
357+
inline aig_network deserialize_network( phmap::BinaryInputArchive& ar_input )
358+
{
359+
auto result = deserialize_network_fallible( ar_input );
360+
(void)result.has_value();
361+
assert( result.has_value() && "failed to deserialize the network onto stream" );
362+
return *result;
300363
}
301364

302365
/*! \brief Deserializes a combinational AIG network from a file

‎lib/parallel_hashmap/parallel_hashmap/phmap_dump.h

+28-8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <iostream>
2323
#include <fstream>
24+
#include <limits>
2425
#include <sstream>
2526
#include "phmap.h"
2627
namespace phmap
@@ -188,47 +189,66 @@ bool parallel_hash_set<N, RefSet, Mtx_, Policy, Hash, Eq, Alloc>::load(InputArch
188189
// ------------------------------------------------------------------------
189190
class BinaryOutputArchive {
190191
public:
191-
BinaryOutputArchive(const char *file_path) {
192+
BinaryOutputArchive(const char *file_path,
193+
size_t bytes_remaining = std::numeric_limits<size_t>::max())
194+
: bytes_remaining_(bytes_remaining) {
192195
ofs_.open(file_path, std::ios_base::binary);
193196
}
194197

195198
bool dump(const char *p, size_t sz) {
199+
if ( sz > bytes_remaining_ ) {
200+
bytes_remaining_ = 0;
201+
return false;
202+
}
203+
bytes_remaining_ -= sz;
196204
ofs_.write(p, sz);
197-
return true;
205+
return ofs_.good();
198206
}
199207

200208
template<typename V>
201209
typename std::enable_if<type_traits_internal::IsTriviallyCopyable<V>::value, bool>::type
202210
dump(const V& v) {
203-
ofs_.write(reinterpret_cast<const char *>(&v), sizeof(V));
204-
return true;
211+
return dump(reinterpret_cast<const char *>(&v), sizeof(V));
212+
}
213+
214+
bool close() {
215+
ofs_.close();
216+
return ofs_.good();
205217
}
206218

207219
private:
208220
std::ofstream ofs_;
221+
size_t bytes_remaining_;
209222
};
210223

211224

212225
class BinaryInputArchive {
213226
public:
214-
BinaryInputArchive(const char * file_path) {
227+
BinaryInputArchive(const char * file_path,
228+
size_t bytes_remaining = std::numeric_limits<size_t>::max())
229+
: bytes_remaining_(bytes_remaining) {
215230
ifs_.open(file_path, std::ios_base::binary);
216231
}
217232

218233
bool load(char* p, size_t sz) {
234+
if ( sz > bytes_remaining_ ) {
235+
bytes_remaining_ = 0;
236+
return false;
237+
}
238+
bytes_remaining_ -= sz;
219239
ifs_.read(p, sz);
220-
return true;
240+
return ifs_.good();
221241
}
222242

223243
template<typename V>
224244
typename std::enable_if<type_traits_internal::IsTriviallyCopyable<V>::value, bool>::type
225245
load(V* v) {
226-
ifs_.read(reinterpret_cast<char *>(v), sizeof(V));
227-
return true;
246+
return load(reinterpret_cast<char *>(v), sizeof(V));
228247
}
229248

230249
private:
231250
std::ifstream ifs_;
251+
size_t bytes_remaining_;
232252
};
233253

234254
} // namespace phmap

‎test/io/serialize.cpp

+78-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
11
#include <catch.hpp>
22

3+
#include <filesystem>
4+
35
#include <mockturtle/io/serialize.hpp>
46

57
using namespace mockturtle;
68

9+
#if __GNUC__ == 7
10+
namespace fs = std::experimental::filesystem::v1;
11+
#else
12+
namespace fs = std::filesystem;
13+
#endif
14+
15+
static constexpr char file_name[] = "aig.dmp" ;
16+
717
TEST_CASE( "serialize aig_network into a file", "[serialize]" )
818
{
919
aig_network aig;
@@ -19,10 +29,10 @@ TEST_CASE( "serialize aig_network into a file", "[serialize]" )
1929
aig.create_po( f5 );
2030

2131
/* serialize */
22-
serialize_network( aig, "aig.dmp" );
32+
serialize_network( aig, file_name );
2333

2434
/* deserialize */
25-
aig_network aig2 = deserialize_network( "aig.dmp" );
35+
aig_network aig2 = deserialize_network( file_name );
2636

2737
CHECK( aig.size() == aig2.size() );
2838
CHECK( aig.num_cis() == aig2.num_cis() );
@@ -46,3 +56,69 @@ TEST_CASE( "serialize aig_network into a file", "[serialize]" )
4656
CHECK( aig2._storage->nodes[f5.index].children[0u].index == f4.index );
4757
CHECK( aig2._storage->nodes[f5.index].children[1u].index == f3.index );
4858
}
59+
60+
static aig_network create_network()
61+
{
62+
aig_network aig;
63+
64+
const auto a = aig.create_pi();
65+
const auto b = aig.create_pi();
66+
67+
const auto f1 = aig.create_nand( a, b );
68+
const auto f3 = aig.create_nand( b, f1 );
69+
const auto f4 = aig.create_nand( a, f1 );
70+
const auto f5 = aig.create_nand( f4, f3 );
71+
aig.create_po( f5 );
72+
73+
return aig;
74+
}
75+
76+
// These numbers were chosen to get 100% coverage of the `return false`
77+
// error paths in `serialize.hpp`.
78+
//
79+
// To find a value that gives coverage of a particular `return false`
80+
// statement, change the loops below to iterate from 0 to 1000000,
81+
// configure with `-DCMAKE_BUILD_TYPE=DEBUG` and run in a debugger
82+
// with a breakpoint set at the line of interest. When the breakpoint
83+
// is hit, get the value of `size` from its stack frame and add it
84+
// to this list.
85+
static constexpr int truncate_sizes[] =
86+
{
87+
0, 8, 16, 32, 40, 344, 352, 368, 376, 384, 672120
88+
};
89+
90+
TEST_CASE( "write errors are propagated", "[serialize]" )
91+
{
92+
aig_network aig = create_network();
93+
94+
serialize_network( aig, file_name );
95+
size_t file_size = fs::file_size( file_name );
96+
INFO("File size " << file_size);
97+
98+
for ( size_t size : truncate_sizes ) {
99+
if ( size >= file_size )
100+
{
101+
break;
102+
}
103+
phmap::BinaryOutputArchive output ( file_name, size );
104+
CHECK_FALSE( serialize_network_fallible( aig, output ) );
105+
}
106+
}
107+
108+
TEST_CASE( "read errors are propagated", "[serialize]" )
109+
{
110+
aig_network aig = create_network();
111+
112+
serialize_network( aig, file_name );
113+
size_t file_size = fs::file_size( file_name );
114+
INFO("File size " << file_size);
115+
116+
for ( size_t size : truncate_sizes ) {
117+
if ( size >= file_size )
118+
{
119+
break;
120+
}
121+
phmap::BinaryInputArchive input ( file_name, size );
122+
CHECK_FALSE( deserialize_network_fallible( input ).has_value() );
123+
}
124+
}

‎test/io/write_aiger.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@ TEST_CASE( "write AIG for XOR into AIGERfile", "[write_aiger]" )
7878
seq_buffer<char> buffer;
7979
std::ostream os( &buffer );
8080
write_aiger( aig, os );
81-
write_aiger( aig, "test.aig" );
8281

8382
CHECK( buffer.data() ==
8483
std::vector<char>{

‎test/test.cpp

+40-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,43 @@
1-
#define CATCH_CONFIG_MAIN
1+
#define CATCH_CONFIG_RUNNER
2+
23
#include <catch.hpp>
34

4-
#include <mockturtle/mockturtle.hpp>
5+
#include <filesystem>
6+
#include <random>
7+
8+
#include <fmt/format.h>
9+
10+
#if __GNUC__ == 7
11+
namespace fs = std::experimental::filesystem::v1;
12+
#else
13+
namespace fs = std::filesystem;
14+
#endif
15+
16+
// Insecure but portable creation of a temporary directory
17+
static fs::path temp_directory()
18+
{
19+
std::random_device rd;
20+
std::mt19937_64 generator( rd() );
21+
std::uniform_int_distribution<uint64_t> distribution( 0, std::numeric_limits<uint64_t>::max() );
22+
uint64_t random_number = distribution( generator );
23+
24+
fs::path path = fs::temp_directory_path();
25+
path /= fmt::format( "mockturtle_test_{:x}", random_number );
26+
return path;
27+
}
28+
29+
int main(int argc, char* argv[])
30+
{
31+
auto temp_dir = temp_directory();
32+
fs::create_directory( temp_dir );
33+
fs::current_path( temp_dir );
34+
35+
int exit_code = Catch::Session().run(argc, argv);
36+
37+
if ( !exit_code ) {
38+
fs::current_path( ".." );
39+
fs::remove_all( temp_dir );
40+
}
541

6-
#include <iostream>
42+
return exit_code;
43+
}

0 commit comments

Comments
 (0)
Please sign in to comment.