Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion graph_net/test/backward_graph_extractor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@
GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(
os.path.dirname(graph_net.__file__))")
GRAPHNET_ROOT="$GRAPH_NET_ROOT/../"

# Device rewrite pass
DEVICE_REWRITE_OUTPUT_DIR="/tmp/device_rewrited_samples"
mkdir -p "$DEVICE_REWRITE_OUTPUT_DIR"

python3 -m graph_net.model_path_handler \
--model-path-list "graph_net/config/small100_torch_samples_list.txt" \
--handler-config $(base64 -w 0 <<EOF
{
"handler_path": "$GRAPHNET_ROOT/graph_net/torch/sample_pass/device_rewrite_sample_pass.py",
"handler_class_name": "DeviceRewriteSamplePass",
"handler_config": {
"device": "cuda",
"resume": false,
"model_path_prefix": "$GRAPHNET_ROOT",
"output_dir": "$DEVICE_REWRITE_OUTPUT_DIR"
}
}
EOF
)

echo "Device rewrite pass completed!"

# Backward graph extraction
OUTPUT_DIR="/tmp/backward_graph_samples"
mkdir -p "$OUTPUT_DIR"

Expand All @@ -12,7 +36,7 @@ python3 -m graph_net.apply_sample_pass \
--sample-pass-class-name "BackwardGraphExtractorPass" \
--sample-pass-config $(base64 -w 0 <<EOF
{
"model_path_prefix": "$GRAPHNET_ROOT",
"model_path_prefix": "$DEVICE_REWRITE_OUTPUT_DIR",
"output_dir": "$OUTPUT_DIR",
"device": "cuda"
}
Expand Down
58 changes: 34 additions & 24 deletions graph_net/torch/sample_pass/backward_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,43 @@

class BackwardGraphExtractor:
def __init__(self, model_name, model_path, output_dir, device):
self.model_name = model_name
self.model_path = model_path
self.output_dir = output_dir
self.device = device
self.builtin_extractor = BuiltinGraphExtractor(
name=model_name,
dynamic=False,
mut_graph_codes=[],
placeholder_auto_rename=False,
workspace_path=output_dir,
)

def __call__(self):
module, example_inputs = get_torch_module_and_inputs(
module, forward_inputs = get_torch_module_and_inputs(
self.model_path, use_dummy_inputs=False, device=self.device
)
module.train()

example_inputs = self.set_requires_grad_for_forward_inputs(
self.model_path, module, example_inputs
forward_inputs = self.set_requires_grad_for_forward_inputs(
self.model_path, module, forward_inputs
)
bw_gm, backward_inputs = self.capture_backward_graph(module, example_inputs)
self.builtin_extractor(bw_gm, backward_inputs)
gm_holder, backward_inputs = self.capture_graph(module, forward_inputs)
self.get_extractor("forward")(gm_holder["forward_gm"], forward_inputs)
self.get_extractor("backward")(gm_holder["backward_gm"], backward_inputs)

def capture_backward_graph(self, module, example_inputs):
backward_gm_holder = {}
def get_extractor(self, suffix):
return BuiltinGraphExtractor(
name=f"{self.model_name}_{suffix}",
dynamic=False,
mut_graph_codes=[],
placeholder_auto_rename=False,
workspace_path=self.output_dir,
)

def capture_graph(self, module, forward_inputs):
gm_holder = {}
backward_inputs = []

def forward_compiler(fx_gm, example_inputs):
def forward_compiler(fx_gm, forward_inputs):
gm_holder["forward_gm"] = fx_gm
return fx_gm

def backward_compiler(fx_gm, example_inputs):
# Save the backward fx.Graph
backward_gm_holder["gm"] = fx_gm
def backward_compiler(fx_gm, forward_inputs):
gm_holder["backward_gm"] = fx_gm

placeholders = [n for n in fx_gm.graph.nodes if n.op == "placeholder"]
origin_forward = fx_gm.forward
Expand All @@ -63,11 +67,11 @@ def wrapped_forward(*args):

compiled = aot_module_simplified(
module,
example_inputs,
forward_inputs,
fw_compiler=forward_compiler,
bw_compiler=backward_compiler,
)
outs = compiled(*example_inputs)
outs = compiled(*forward_inputs)
outs = [outs] if isinstance(outs, torch.Tensor) else outs
valid_pairs = [
(out, torch.ones_like(out))
Expand All @@ -78,9 +82,11 @@ def wrapped_forward(*args):
if valid_pairs:
tensors, grads = zip(*valid_pairs)
torch.autograd.backward(tensors, grads)
gm_holder["backward_gm"] = self._remove_none_from_output(
gm_holder["backward_gm"]
)

bw_gm = self._remove_none_from_output(backward_gm_holder["gm"])
return bw_gm, backward_inputs
return gm_holder, backward_inputs

def _remove_none_from_output(self, gm):
output_node = next(
Expand Down Expand Up @@ -167,9 +173,13 @@ def sample_handled(self, rel_model_path: str) -> bool:

def resume(self, rel_model_path: str):
model_path_prefix = Path(self.config["model_path_prefix"])
model_name = f"{os.path.basename(rel_model_path)}_backward"
model_name = f"{os.path.basename(rel_model_path)}"
model_path = model_path_prefix / rel_model_path
output_dir = Path(self.config["output_dir"]) / os.path.dirname(rel_model_path)
output_dir = (
Path(self.config["output_dir"])
/ os.path.dirname(rel_model_path)
/ model_name
)
device = self._choose_device(self.config["device"])
extractor = BackwardGraphExtractor(model_name, model_path, output_dir, device)
extractor()
Expand Down
Loading