diff --git a/graph_net/test/dtype_gen_test.sh b/graph_net/test/dtype_gen_test.sh index 31528ef5a..c8283b8ae 100755 --- a/graph_net/test/dtype_gen_test.sh +++ b/graph_net/test/dtype_gen_test.sh @@ -43,25 +43,3 @@ EOF ) -# Step 3: Valiation -SUCCESS_CNT=0 -FAIL_CNT=0 - -for model_path in "$OUTPUT_DIR"/*; do - echo "[VALIDATE] $model_path" - - output=$(python -m graph_net.torch.validate \ - --model-path "$model_path" 2>&1) - - if echo "$output" | grep -q "Validation success, model_path="; then - echo "SUCCESS" - ((SUCCESS_CNT++)) - else - echo "FAIL" - ((FAIL_CNT++)) - fi -done - -echo "====================" -echo "SUCCESS $SUCCESS_CNT" -echo "FAIL $FAIL_CNT" \ No newline at end of file diff --git a/graph_net/torch/sample_pass/dtype_generalizer.py b/graph_net/torch/sample_pass/dtype_generalizer.py index 3ed56b32e..a8fec4d3f 100644 --- a/graph_net/torch/sample_pass/dtype_generalizer.py +++ b/graph_net/torch/sample_pass/dtype_generalizer.py @@ -42,6 +42,8 @@ from graph_net.sample_pass.sample_pass import SamplePass from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin +from graph_net.hash_util import get_sha256_hash + # Weights that must remain float32 for numerical stability FLOAT32_PRESERVED_WEIGHTS = { "running_mean", @@ -296,7 +298,7 @@ def sample_handled(self, rel_model_path: str) -> bool: def __call__(self, rel_model_path: str): self.resumable_handle_sample(rel_model_path) - def resume(self, model_path: str) -> List[str]: + def resume(self, rel_model_path: str) -> List[str]: """ Apply dtype passes to generate new samples. 
@@ -308,24 +310,36 @@ def resume(self, model_path: str) -> List[str]: """ # Apply model_path_prefix if provided if self.model_path_prefix: - model_path = str(Path(self.model_path_prefix) / model_path) + abs_model_path = str(Path(self.model_path_prefix) / rel_model_path) # Read pass names from graph_net.json - dtype_pass_names = self._read_dtype_pass_names(model_path) + dtype_pass_names = self._read_dtype_pass_names(abs_model_path) if not dtype_pass_names: - logging.warning(f"No dtype passes found in {model_path}/graph_net.json") + logging.warning(f"No dtype passes found in {abs_model_path}/graph_net.json") return [] # Parse the computation graph - traced_model = parse_immutable_model_path_into_sole_graph_module(model_path) + traced_model = parse_immutable_model_path_into_sole_graph_module(abs_model_path) + + # Copy the original sample + files_copied = [ + "model.py", + "graph_hash.txt", + "graph_net.json", + "weight_meta.py", + "input_meta.py", + "input_tensor_constraints.py", + "subgraph_sources.json", + ] + self._copy_sample_files(rel_model_path, "float32", files_copied) # Generate samples for each pass generated_samples = [] for pass_name in dtype_pass_names: try: sample_dir = self._apply_pass_and_generate( - model_path, traced_model, pass_name + rel_model_path, traced_model, pass_name ) generated_samples.append(sample_dir) logging.info(f"Generated sample: {sample_dir}") @@ -388,8 +402,7 @@ def _apply_pass_and_generate( gm_modified = dtype_pass.rewrite(gm_copy) # Generate output directory - model_name = Path(model_path).name - output_sample_dir = Path(self.output_dir) / f"{model_name}_{dtype}" + output_sample_dir = Path(self.output_dir) / dtype / model_path output_sample_dir.mkdir(parents=True, exist_ok=True) # Write modified model.py @@ -398,11 +411,20 @@ def _apply_pass_and_generate( with open(output_sample_dir / "model.py", "w") as f: f.write(write_code) + # Write modified graph_hash.txt + model_hash = get_sha256_hash(model_code) + with open(output_sample_dir 
/ "graph_hash.txt", "w") as f: + f.write(model_hash) + # Copy metadata files - for fname in ["graph_net.json", "weight_meta.py", "input_meta.py"]: - src = Path(model_path) / fname - if src.exists(): - shutil.copy(src, output_sample_dir / fname) + files_copied = [ + "graph_net.json", + "weight_meta.py", + "input_meta.py", + "input_tensor_constraints.py", + "subgraph_sources.json", + ] + self._copy_sample_files(model_path, dtype, files_copied) # Update graph_net.json with dtype information self._update_sample_metadata(output_sample_dir, dtype) @@ -429,6 +451,25 @@ def _update_sample_metadata(self, sample_dir: Path, dtype: str) -> None: update_json(graph_net_json_path, kDtypeGeneralizationPrecision, dtype) update_json(graph_net_json_path, kDtypeGeneralizationGenerated, True) + def _copy_sample_files( + self, rel_model_path: str, dtype: str, files_copied: list + ) -> None: + """ + Copy files of sample. + + Args: + rel_model_path: relative model path + """ + # Generate output directory + output_sample_dir = Path(self.output_dir) / dtype / rel_model_path + output_sample_dir.mkdir(parents=True, exist_ok=True) + + # Copy files of original sample + for fname in files_copied: + src = Path(rel_model_path) / fname + if src.exists(): + shutil.copy(src, output_sample_dir / fname) + class MultiDtypeFilter: """