diff --git a/graph_net/test/backward_graph_extractor.sh b/graph_net/test/backward_graph_extractor.sh index 73819cf2c..1a12ab0f5 100644 --- a/graph_net/test/backward_graph_extractor.sh +++ b/graph_net/test/backward_graph_extractor.sh @@ -3,6 +3,30 @@ GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print( os.path.dirname(graph_net.__file__))") GRAPHNET_ROOT="$GRAPH_NET_ROOT/../" + +# Device rewrite pass +DEVICE_REWRITE_OUTPUT_DIR="/tmp/device_rewrited_samples" +mkdir -p "$DEVICE_REWRITE_OUTPUT_DIR" + +python3 -m graph_net.model_path_handler \ + --model-path-list "graph_net/config/small100_torch_samples_list.txt" \ + --handler-config $(base64 -w 0 < bool: def resume(self, rel_model_path: str): model_path_prefix = Path(self.config["model_path_prefix"]) - model_name = f"{os.path.basename(rel_model_path)}_backward" + model_name = f"{os.path.basename(rel_model_path)}" model_path = model_path_prefix / rel_model_path - output_dir = Path(self.config["output_dir"]) / os.path.dirname(rel_model_path) + output_dir = ( + Path(self.config["output_dir"]) + / os.path.dirname(rel_model_path) + / model_name + ) device = self._choose_device(self.config["device"]) extractor = BackwardGraphExtractor(model_name, model_path, output_dir, device) extractor()