22 changes: 0 additions & 22 deletions graph_net/test/dtype_gen_test.sh
@@ -43,25 +43,3 @@ EOF
)


# Step 3: Validation
SUCCESS_CNT=0
FAIL_CNT=0

for model_path in "$OUTPUT_DIR"/*; do
echo "[VALIDATE] $model_path"

output=$(python -m graph_net.torch.validate \
--model-path "$model_path" 2>&1)

if echo "$output" | grep -q "Validation success, model_path="; then
echo "SUCCESS"
((SUCCESS_CNT++))
else
echo "FAIL"
((FAIL_CNT++))
fi
done

echo "===================="
echo "SUCCESS $SUCCESS_CNT"
echo "FAIL $FAIL_CNT"
65 changes: 53 additions & 12 deletions graph_net/torch/sample_pass/dtype_generalizer.py
@@ -42,6 +42,8 @@
from graph_net.sample_pass.sample_pass import SamplePass
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin

from graph_net.hash_util import get_sha256_hash

# Weights that must remain float32 for numerical stability
FLOAT32_PRESERVED_WEIGHTS = {
"running_mean",
@@ -296,7 +298,7 @@ def sample_handled(self, rel_model_path: str) -> bool:
def __call__(self, rel_model_path: str):
self.resumable_handle_sample(rel_model_path)

def resume(self, model_path: str) -> List[str]:
def resume(self, rel_model_path: str) -> List[str]:
"""
Apply dtype passes to generate new samples.

@@ -308,24 +310,36 @@ def resume(self, model_path: str) -> List[str]:
"""
# Apply model_path_prefix if provided
if self.model_path_prefix:
model_path = str(Path(self.model_path_prefix) / model_path)
abs_model_path = str(Path(self.model_path_prefix) / rel_model_path)

# Read pass names from graph_net.json
dtype_pass_names = self._read_dtype_pass_names(model_path)
dtype_pass_names = self._read_dtype_pass_names(abs_model_path)

if not dtype_pass_names:
logging.warning(f"No dtype passes found in {model_path}/graph_net.json")
logging.warning(f"No dtype passes found in {abs_model_path}/graph_net.json")
return []

# Parse the computation graph
traced_model = parse_immutable_model_path_into_sole_graph_module(model_path)
traced_model = parse_immutable_model_path_into_sole_graph_module(abs_model_path)

# Copy the original sample
files_copied = [
"model.py",
"graph_hash.txt",
"graph_net.json",
"weight_meta.py",
"input_meta.py",
"input_tensor_constraints.py",
"subgraph_sources.json",
]
self._copy_sample_files(rel_model_path, "float32", files_copied)

# Generate samples for each pass
generated_samples = []
for pass_name in dtype_pass_names:
try:
sample_dir = self._apply_pass_and_generate(
model_path, traced_model, pass_name
rel_model_path, traced_model, pass_name
)
generated_samples.append(sample_dir)
logging.info(f"Generated sample: {sample_dir}")
@@ -388,8 +402,7 @@ def _apply_pass_and_generate(
gm_modified = dtype_pass.rewrite(gm_copy)

# Generate output directory
model_name = Path(model_path).name
output_sample_dir = Path(self.output_dir) / f"{model_name}_{dtype}"
output_sample_dir = Path(self.output_dir) / dtype / model_path
output_sample_dir.mkdir(parents=True, exist_ok=True)

# Write modified model.py
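The output layout changes in the hunk above: generated samples used to land in a flat <model_name>_<dtype> directory, and now they are nested as <output_dir>/<dtype>/<model_path>, matching the float32 baseline copy made in resume(). A small sketch of the two schemes (all paths are hypothetical):

from pathlib import Path

output_dir = Path("/data/dtype_generalized")  # hypothetical
model_path = "vision/resnet18"                # hypothetical relative model path
dtype = "float16"                             # hypothetical dtype pass output

# Old scheme: flat directory keyed by the model's basename.
print(output_dir / f"{Path(model_path).name}_{dtype}")
# -> /data/dtype_generalized/resnet18_float16

# New scheme: dtype-first tree that preserves the relative model path.
print(output_dir / dtype / model_path)
# -> /data/dtype_generalized/float16/vision/resnet18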
@@ -398,11 +411,20 @@
with open(output_sample_dir / "model.py", "w") as f:
f.write(write_code)

# Write modified graph_hash.txt
model_hash = get_sha256_hash(model_code)
with open(output_sample_dir / "graph_hash.txt", "w") as f:
f.write(model_hash)

# Copy metadata files
for fname in ["graph_net.json", "weight_meta.py", "input_meta.py"]:
src = Path(model_path) / fname
if src.exists():
shutil.copy(src, output_sample_dir / fname)
files_copied = [
"graph_net.json",
"weight_meta.py",
"input_meta.py",
"input_tensor_constraints.py",
"subgraph_sources.json",
]
self._copy_sample_files(model_path, dtype, files_copied)

# Update graph_net.json with dtype information
self._update_sample_metadata(output_sample_dir, dtype)
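Each generated sample now also gets a freshly computed graph_hash.txt, derived from the modified model source via get_sha256_hash. The helper itself is not part of this diff; a minimal sketch, assuming graph_net.hash_util.get_sha256_hash simply returns the SHA-256 hex digest of the source string:

import hashlib

def get_sha256_hash(text: str) -> str:
    # Assumed behavior of graph_net.hash_util.get_sha256_hash.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

model_code = "def forward(x):\n    return x"  # hypothetical model source
with open("graph_hash.txt", "w") as f:
    f.write(get_sha256_hash(model_code))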
@@ -429,6 +451,25 @@ def _update_sample_metadata(self, sample_dir: Path, dtype: str) -> None:
update_json(graph_net_json_path, kDtypeGeneralizationPrecision, dtype)
update_json(graph_net_json_path, kDtypeGeneralizationGenerated, True)

def _copy_sample_files(
self, rel_model_path: str, dtype: str, files_copied: list
) -> None:
"""
Copy the original sample's files into the output directory.

Args:
rel_model_path: relative model path
dtype: precision name used as the output subdirectory
files_copied: names of the files to copy
"""
# Generate output directory
output_sample_dir = Path(self.output_dir) / dtype / rel_model_path
output_sample_dir.mkdir(parents=True, exist_ok=True)

# Copy files of original sample
for fname in files_copied:
src = Path(rel_model_path) / fname
if src.exists():
shutil.copy(src, output_sample_dir / fname)
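Note that _copy_sample_files only copies files that actually exist in the source sample, so optional metadata such as input_tensor_constraints.py or subgraph_sources.json is skipped rather than treated as an error. A standalone sketch of that copy-if-present behavior (directories are hypothetical):

import shutil
from pathlib import Path

src_dir = Path("vision/resnet18")                             # hypothetical sample directory
dst_dir = Path("dtype_generalized/float32/vision/resnet18")   # hypothetical output directory
dst_dir.mkdir(parents=True, exist_ok=True)

for fname in ["graph_net.json", "weight_meta.py", "input_meta.py",
              "input_tensor_constraints.py", "subgraph_sources.json"]:
    src = src_dir / fname
    if src.exists():                    # optional files are simply skipped
        shutil.copy(src, dst_dir / fname)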


class MultiDtypeFilter:
"""