Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved parameter file parsing #107

Merged
merged 2 commits into from
Oct 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/deform/registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,13 @@ int RegistrationCommand::_execute(void)
param_str << f.rdbuf();

LOG(Info) << "Running with parameter file: '" << param_file << "'";
if (!parse_registration_settings(param_str.str(), settings))
try {
parse_registration_settings(param_str.str(), settings);
}
catch (std::exception& e) {
LOG(Error) << e.what();
return EXIT_FAILURE;
}

// Print only contents of parameter file to Info
LOG(Info) << "Parameters:" << std::endl << param_str.str();
Expand Down
260 changes: 131 additions & 129 deletions src/deform_lib/registration/settings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,38 +219,34 @@ namespace YAML {
throw YAML::RepresentationException(node.Mark(), "expected image slot");
}

// Cost functions
auto& cf = node["cost_function"];
if (cf) {
if (cf.IsSequence()) {
out.cost_functions.resize(cf.size());
for(size_t k = 0; k < cf.size(); ++k) {
out.cost_functions[k] = cf[k].as<Settings::ImageSlot::WeightedFunction>();
for (const auto& c : node) {
std::string key = c.first.as<std::string>();
const YAML::Node& value = c.second;

if (key == "cost_function") {
if (value.IsSequence()) {
out.cost_functions.resize(value.size());
for(size_t k = 0; k < value.size(); ++k) {
out.cost_functions[k] = value[k].as<Settings::ImageSlot::WeightedFunction>();
}
}
else {
// NOTE: assuming that the constructor of ImageSlot initialises
// at least the first cost function in the array
out.cost_functions[0].function = value.as<Settings::ImageSlot::CostFunction>();
out.cost_functions[0].weight = 1.0f;
}
}
else {
// NOTE: assuming that the constructor of ImageSlot initialises
// at least the first cost function in the array
out.cost_functions[0].function = cf.as<Settings::ImageSlot::CostFunction>();
out.cost_functions[0].weight = 1.0f;
else if (key == "resampler") {
out.resample_method = value.as<Settings::ImageSlot::ResampleMethod>();
}
}

// Resampling method
if (node["resampler"]) {
out.resample_method = node["resampler"].as<Settings::ImageSlot::ResampleMethod>();
}

// Normalisation
if (node["normalize"]) {
try {
out.normalize = node["normalize"].as<bool>();
else if (key == "normalize") {
out.normalize = value.as<bool>();
}
catch (YAML::TypedBadConversion<bool> &) {
throw YAML::RepresentationException(node.Mark(), "expected bool");
else {
throw YAML::RepresentationException(node.Mark(), "Unrecognized image slot parameter: " + key);
}
}

return true;
}
};
Expand Down Expand Up @@ -303,47 +299,44 @@ const char* update_rule_to_str(Settings::UpdateRule op)
return "none";
}

static void parse_level(const YAML::Node& node, Settings::Level& out) {
if(!node.IsMap()) {
throw YAML::RepresentationException(node.Mark(), "expected level");
// Parses a level specific parameter
// Returns true if any parameter was actually set, false if not
static bool parse_level_parameter(
const std::string& key,
const YAML::Node& value,
Settings::Level& out
)
{
if (key == "block_size") {
out.block_size = value.as<int3>();
}

if (node["block_size"]) {
out.block_size = node["block_size"].as<int3>();
else if (key == "block_energy_epsilon") {
out.block_energy_epsilon = value.as<float>();
}

if (node["block_energy_epsilon"]) {
out.block_energy_epsilon = node["block_energy_epsilon"].as<float>();
else if (key == "max_iteration_count") {
out.max_iteration_count = value.as<int>();
}

if (node["max_iteration_count"]) {
out.max_iteration_count = node["max_iteration_count"].as<int>();
else if (key == "regularization_weight") {
out.regularization_weight = value.as<float>();
}

if (node["regularization_weight"]) {
out.regularization_weight = node["regularization_weight"].as<float>();
}

if (node["regularization_scale"]) {
out.regularization_scale = node["regularization_scale"].as<float>();
else if (key == "regularization_scale") {
out.regularization_scale = value.as<float>();
}

if (node["regularization_exponent"]) {
out.regularization_exponent = node["regularization_exponent"].as<float>();
else if (key == "regularization_exponent") {
out.regularization_exponent = value.as<float>();
}

if (node["step_size"]) {
else if (key == "step_size") {
try {
out.step_size = node["step_size"].as<float3>();
out.step_size = value.as<float3>();
}
catch (YAML::RepresentationException&) {
try {
float f = node["step_size"].as<float>();
float f = value.as<float>();
out.step_size = {f, f, f};
}
catch (YAML::RepresentationException&) {
throw YAML::RepresentationException(
node["step_size"].Mark(),
value.Mark(),
"expected float or sequence of three floats"
);
}
Expand All @@ -357,18 +350,19 @@ static void parse_level(const YAML::Node& node, Settings::Level& out) {
throw ValidationError("Settings: Invalid step_size, step_size should be greater than zero");
}
}

if (node["constraints_weight"]) {
out.constraints_weight = node["constraints_weight"].as<float>();
else if (key == "constraints_weight") {
out.constraints_weight = value.as<float>();
}

if (node["landmarks_weight"]) {
out.landmarks_weight = node["landmarks_weight"].as<float>();
else if (key == "landmarks_weight") {
out.landmarks_weight = value.as<float>();
}

if (node["landmarks_decay"]) {
out.landmarks_decay = node["landmarks_decay"].as<float>();
else if (key == "landmarks_decay") {
out.landmarks_decay = value.as<float>();
}
else {
return false;
}
return true;
}


Expand Down Expand Up @@ -422,55 +416,36 @@ void print_registration_settings(const Settings& settings, std::ostream& s)
}
}

bool parse_registration_settings(const std::string& str, Settings& settings)
void parse_registration_settings(const std::string& str, Settings& settings)
{
settings = {}; // Clean up

try {

YAML::Node root = YAML::Load(str);

YAML::Node root = YAML::Load(str);
// First pass we parse global level settings
Settings::Level global_level_settings;

if (root["pyramid_levels"]) {
settings.num_pyramid_levels = root["pyramid_levels"].as<int>();
}

if (root["pyramid_stop_level"]) {
settings.pyramid_stop_level = root["pyramid_stop_level"].as<int>();
}

// First parse global level settings
Settings::Level global_level_settings;
parse_level(root, global_level_settings);

// Apply global settings for all levels
settings.levels.resize(settings.num_pyramid_levels);
for (int i = 0; i < settings.num_pyramid_levels; ++i) {
settings.levels[i] = global_level_settings;
// Global settings not connected to specific levels
for (const auto& node : root) {
std::string key = node.first.as<std::string>();
const YAML::Node& value = node.second;

if (key == "pyramid_levels") {
settings.num_pyramid_levels = value.as<int>();
}

// Parse per-level overrides
auto levels = root["levels"];
if (levels) {
for (auto it = levels.begin(); it != levels.end(); ++it) {
int l = it->first.as<int>();
if (l >= settings.num_pyramid_levels) {
throw ValidationError("Settings: index of level exceed number specified in pyramid_levels");
}
parse_level(it->second, settings.levels[l]);
}
else if (key == "pyramid_stop_level") {
settings.pyramid_stop_level = value.as<int>();
}

if (root["landmarks_stop_level"]) {
settings.landmarks_stop_level = root["landmarks_stop_level"].as<int>();
else if (key == "landmarks_stop_level") {
settings.landmarks_stop_level = value.as<int>();
}

if (root["regularize_initial_displacement"]) {
settings.regularize_initial_displacement
= root["regularize_initial_displacement"].as<bool>();
else if (key == "regularize_initial_displacement") {
settings.regularize_initial_displacement = value.as<bool>();
}

if (root["solver"]) {
std::string solver = root["solver"].as<std::string>();
else if (key == "solver") {
std::string solver = value.as<std::string>();

if (solver == "icm") {
settings.solver = Settings::Solver_ICM;
Expand All @@ -492,10 +467,10 @@ bool parse_registration_settings(const std::string& str, Settings& settings)
else {
throw ValidationError("Settings: Invalid solver");
}
}

if (root["update_rule"]) {
std::string rule = root["update_rule"].as<std::string>();
}
else if (key == "update_rule") {
std::string rule = value.as<std::string>();
if (rule == "additive") {
settings.update_rule = Settings::UpdateRule_Additive;
}
Expand All @@ -507,8 +482,8 @@ bool parse_registration_settings(const std::string& str, Settings& settings)
for (int i = settings.pyramid_stop_level; i < settings.num_pyramid_levels; ++i) {
if (settings.levels[i].regularization_exponent != 2) {
LOG(Warning) << "Submodularity is only guaranteed for "
<< "regularization_exponent=2 when using the "
<< "compositive update rule";
<< "regularization_exponent=2 when using the "
<< "compositive update rule";
break;
}
}
Expand All @@ -517,36 +492,63 @@ bool parse_registration_settings(const std::string& str, Settings& settings)
throw ValidationError("Settings: Invalid update rule");
}
}

auto is = root["image_slots"];
if (is && is.IsSequence()) {
for (size_t i = 0; i < is.size(); ++i) {
settings.image_slots.push_back(is[i].as<Settings::ImageSlot>());
else if (parse_level_parameter(key, value, global_level_settings)) {
// parse_level_parameter does the parsing
}
else if (key == "image_slots") {
if (value.IsSequence()) {
for (size_t i = 0; i < value.size(); ++i) {
settings.image_slots.push_back(value[i].as<Settings::ImageSlot>());
}
} else {
throw ValidationError("Settings: Expeced 'image_slots' to be a sequence");
}
}
else if (key == "levels") {
// We parse levels in a second pass, to allow global settings to be set
}
else {
std::stringstream ss;
ss << "Settings: Unrecognized parameter: " << key;
throw ValidationError(ss.str());
}
}
catch (YAML::Exception& e) {
LOG(Error) << "[Settings] " << e.what();
return false;

// Apply global settings for all levels
settings.levels.resize(settings.num_pyramid_levels);
for (int i = 0; i < settings.num_pyramid_levels; ++i) {
settings.levels[i] = global_level_settings;
}

return true;
}
// Parse per-level overrides
auto levels = root["levels"];
if (levels) {
for (const auto& level : levels) {
int l = level.first.as<int>();
if (l >= settings.num_pyramid_levels) {
throw ValidationError("Settings: index of level exceed number specified in pyramid_levels");
}

bool parse_registration_file(const std::string& parameter_file, Settings& settings)
{
// Defaults
settings = Settings();
if(!level.second.IsMap()) {
throw YAML::RepresentationException(level.second.Mark(), "expected level");
}

std::ifstream f(parameter_file, std::ifstream::in);
if (!f.is_open()) {
LOG(Error) << "[Settings] Failed to open file '" << parameter_file << "'";
return false;
for (const auto& node : level.second) {
std::string key = node.first.as<std::string>();
if (!parse_level_parameter(key, node.second, settings.levels[l])) {
std::stringstream ss;
ss << "Settings: Unrecognized level parameter: " << node.first.as<std::string>();
throw ValidationError(ss.str());
}
}
}
}

std::stringstream ss;
ss << f.rdbuf();

return parse_registration_settings(ss.str(), settings);
}
catch (YAML::Exception& e) {
std::stringstream ss;
ss << "Settings: " << e.what();

throw ValidationError(ss.str());
}
}

6 changes: 2 additions & 4 deletions src/deform_lib/registration/settings.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,8 @@ struct Settings
void print_registration_settings(const Settings& settings, std::ostream& s);

// Return true if parsing was successful, false if not
bool parse_registration_file(const std::string& parameter_file, Settings& settings);

// Return true if parsing was successful, false if not
bool parse_registration_settings(const std::string& str, Settings& settings);
// throws ValidationError if settings string is invalid
void parse_registration_settings(const std::string& str, Settings& settings);

// Returns the string representation of a given solver
const char* solver_to_str(Settings::Solver solver);
Expand Down
1 change: 1 addition & 0 deletions src/python_wrapper/_pydeform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ stk::Volume registration_wrapper(
// Parse settings
Settings settings_;
if (!settings.is_none()) {
// Convert the python dict into a YAML string, which then is parseable by settings
py::object py_yaml_dump = py::module::import("yaml").attr("dump");
py::object py_settings_str = py_yaml_dump(py::cast<py::dict>(settings));
std::string settings_str = py::cast<std::string>(py_settings_str);
Expand Down
Loading