diff --git a/examples/common/common.hpp b/examples/common/common.hpp index 50f35aed8..3cdb1e4d2 100644 --- a/examples/common/common.hpp +++ b/examples/common/common.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include using json = nlohmann::json; @@ -205,11 +206,51 @@ static void log_printf(sd_log_level_t level, const char* file, int line, const c #define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__) #define LOG_ERROR(format, ...) log_printf(SD_LOG_ERROR, __FILE__, __LINE__, format, ##__VA_ARGS__) +class OptionFlags { + using value_t = uint8_t; +public: + enum type : value_t { + assigned = 0x1, // runtime flag; is set when an option is assigned a value from command line + no_network = 0x2 // options with this flag are not encoded for transfer in json (sd-client/sd-server) + }; + + void set(OptionFlags::type flag) { + vflags |= static_cast(flag); + } + + void set(std::initializer_list flags) { + for (OptionFlags::type f : flags) + vflags |= static_cast(f); + }; + + bool has(OptionFlags::type flag) const { + return vflags & static_cast(flag); + } + + bool has(std::initializer_list flags) const { + for (OptionFlags::type f : flags) + if (!(vflags & static_cast(f))) + return false; + return true; + }; + + OptionFlags() : vflags(0) {} + OptionFlags(OptionFlags::type flags) : vflags(0) { + set(flags); + } + OptionFlags(std::initializer_list flags) : vflags(0) { + set(flags); + } +private: + value_t vflags; +}; + struct StringOption { std::string short_name; std::string long_name; std::string desc; std::string* target; + OptionFlags flags; }; struct IntOption { @@ -217,6 +258,7 @@ struct IntOption { std::string long_name; std::string desc; int* target; + OptionFlags flags; }; struct FloatOption { @@ -224,6 +266,7 @@ struct FloatOption { std::string long_name; std::string desc; float* target; + OptionFlags flags; }; struct BoolOption { @@ -232,6 +275,7 @@ struct BoolOption { std::string desc; bool keep_true; bool* target; + OptionFlags flags; }; struct ManualOption { @@ -239,6 +283,7 @@ struct ManualOption { std::string long_name; std::string desc; std::function cb; + OptionFlags flags; }; struct ArgOptions { @@ -345,7 +390,69 @@ struct ArgOptions { } }; -static bool parse_options(int argc, const char** argv, const std::vector& options_list) { +/** given a params object and the parsed ArgOptions, convert to json for network send **/ +template +json options_to_json(PARAMS const& params, ArgOptions const& options) { + json opt_json = json::object(); + + auto get_opts = [&opt_json](std::vector const& v) { + for (auto const& opt : v) { + if (!opt.flags.has(OptionFlags::assigned) || opt.flags.has(OptionFlags::no_network)) + continue; // skip unassigned or no network options + std::string node = opt.long_name.substr(2); + std::replace(node.begin(), node.end(), '-', '_'); + opt_json[node] = *opt.target; + } + }; + + // automate transfer of simple types + get_opts(options.string_options); + get_opts(options.int_options); + get_opts(options.float_options); + get_opts(options.bool_options); + // manual options conversion hook + params.manual_options_to_json(options.manual_options, opt_json); + + return opt_json; +} + +/** given a params object and a json object, convert from json for network receive **/ +template +bool options_from_json(PARAMS& params, json const& opt_json) { + ArgOptions options = params.get_options(); // instantiate parameters for speculation + + auto set_opts = [&opt_json](std::vector& v) -> bool { + for (auto const& opt : v) { + if (opt.flags.has(OptionFlags::no_network)) + continue; // skip no network options + std::string node = opt.long_name.substr(2); + std::replace(node.begin(), node.end(), '-', '_'); + using T = std::decay_t; + if (opt_json.contains(node)) { + try { + *opt.target = opt_json[node].get(); + // probably only need this if want to rebuild json again + //opt.flags.set(OptionFlags::assigned); + } catch (...) { + LOG_ERROR("options_from_json: error: processing argument `%s`", node.c_str()); + return false; + } + } + } + return true; + }; + + // automate transfer of simple types + if (!set_opts(options.string_options) || + !set_opts(options.int_options) || + !set_opts(options.float_options) || + !set_opts(options.bool_options) || + !params.manual_options_from_json(options.manual_options, opt_json)) + return false; + return true; +} + +static bool parse_options(int argc, const char** argv, std::vector& options_list) { bool invalid_arg = false; std::string arg; @@ -354,6 +461,7 @@ static bool parse_options(int argc, const char** argv, const std::vector 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) { apply_fn(option); + option.flags.set(OptionFlags::assigned); // mark option as assigned for options_to_json() return true; } } @@ -1099,33 +1207,33 @@ struct SDGenerationParams { {"-i", "--init-img", "path to the init image", - &init_image_path}, + &init_image_path, OptionFlags::no_network}, {"", "--end-img", "path to the end image, required by flf2v", - &end_image_path}, + &end_image_path, OptionFlags::no_network}, {"", "--mask", "path to the mask image", - &mask_image_path}, + &mask_image_path, OptionFlags::no_network}, {"", "--control-image", "path to control image, control net", - &control_image_path}, + &control_image_path, OptionFlags::no_network}, {"", "--control-video", "path to control video frames, It must be a directory path. The video frames inside should be stored as images in " "lexicographical (character) order. For example, if the control video path is `frames`, the directory contain images " "such as 00.png, 01.png, ... etc.", - &control_video_path}, + &control_video_path, OptionFlags::no_network}, {"", "--pm-id-images-dir", "path to PHOTOMAKER input id images dir", - &pm_id_images_dir}, + &pm_id_images_dir, OptionFlags::no_network}, {"", "--pm-id-embed-path", "path to PHOTOMAKER v2 id embed", - &pm_id_embed_path}, + &pm_id_embed_path, OptionFlags::no_network}, }; options.int_options = { @@ -1512,7 +1620,7 @@ struct SDGenerationParams { {"-r", "--ref-image", "reference image for Flux Kontext models (can be used multiple times)", - on_ref_image_arg}, + on_ref_image_arg, OptionFlags::no_network}, {"", "--cache-mode", "caching method: 'easycache' (DiT), 'ucache' (UNET), 'dbcache'/'taylorseer'/'cache-dit' (DiT block-level)", @@ -1539,95 +1647,145 @@ struct SDGenerationParams { return options; } - bool from_json_str(const std::string& json_str) { - json j; - try { - j = json::parse(json_str); - } catch (...) { - LOG_ERROR("json parse failed %s", json_str.c_str()); - return false; + +private: + /** build a set assigned options so we can check for availablity faster **/ + template + static std::unordered_set build_optlist(std::vector