Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions graphcore/tools/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def do_rename_pure(self, old_path: str, new_path: str) -> Generator[P, Row, str]
...

def _validate(self, path: str) -> str:
r = pathlib.Path(path).resolve()
if not r.is_relative_to("/memories"):
r = pathlib.PurePosixPath(path)
if not r.is_absolute() or not r.is_relative_to("/memories"):
raise MemoryBackendError(f"{str(r)} is not rooted in /memories")

return str(r)
Expand Down Expand Up @@ -1087,13 +1087,6 @@ class FileSystemMemoryBackend(FileSystemBackendCommon, MemoryBackend[Never, Neve
def __init__(self, storage_folder: pathlib.Path, init_from: pathlib.Path | None = None):
super().__init__(storage_folder, init_from)

def _run_all[T](self, d: Generator[Never, Never, T]) -> T:
try:
next(d)
assert False
except StopIteration as e:
return e.value

@override
def _run_multi[T](self, d: Generator[Never, Never, T]) -> T:
return self._run_all(d)
Expand Down Expand Up @@ -1121,6 +1114,18 @@ async def _run_row[T](self, d: Generator[Never, Never, T]) -> T:
@override
async def _run_update[T](self, d: Generator[Never, Never, T]) -> T:
return self._run_all(d)

class InvalidPathError(RuntimeError):
def __init__(self, msg: str):
super().__init__(msg)
self.msg = msg

def _validate_path(s: str):
p = pathlib.PurePosixPath(s)
if not p.is_absolute():
raise InvalidPathError(f"Memory path: {s} is not an absolute path")
if not p.is_relative_to("/memories"):
raise InvalidPathError(f"Memory path: {s} is not rooted at /memories")

def _memory_tool_impl[R](
backend: MemoryToolImpl[R],
Expand All @@ -1133,6 +1138,7 @@ def _memory_tool_impl[R](
return missing_required("path")
elif args.file_text is None:
return missing_required("file_text")
_validate_path(args.path)
return backend.create(args.path, args.file_text)

case "delete":
Expand All @@ -1147,12 +1153,15 @@ def _memory_tool_impl[R](
return missing_required("insert_line")
elif args.insert_text is None:
return missing_required("insert_text")
_validate_path(args.path)
return backend.insert(args.path, args.insert_line, args.insert_text)
case "rename":
if args.old_path is None:
return missing_required("old_path")
elif args.new_path is None:
return missing_required("new_path")
_validate_path(args.old_path)
_validate_path(args.new_path)
return backend.rename(args.old_path, args.new_path)

case "str_replace":
Expand All @@ -1162,11 +1171,13 @@ def _memory_tool_impl[R](
return missing_required("old_str")
elif args.new_str is None:
return missing_required("new_str")
_validate_path(args.path)
return backend.str_replace(args.path, args.old_str, args.new_str)

case "view":
if args.path is None:
return missing_required("path")
_validate_path(args.path)
range : tuple[int, int] | None = None
if args.view_range is not None and len(args.view_range) >= 2:
range = (args.view_range[0], args.view_range[1])
Expand All @@ -1182,9 +1193,12 @@ class MemoryTool(WithAsyncImplementation[str], UnifiedMemorySchema):
"""
@override
async def run(self) -> str:
return await _memory_tool_impl(
backend, self, missing_required
)
try:
return await _memory_tool_impl(
backend, self, missing_required
)
except InvalidPathError as e:
return f"Error: {e.msg}"
return MemoryTool.as_tool("memory")

def memory_tool(backend: MemoryToolImpl[str]) -> BaseTool:
Expand All @@ -1200,7 +1214,10 @@ class MemoryTool(WithImplementation[str], UnifiedMemorySchema):
"""
@override
def run(self) -> str:
return _memory_tool_impl(
backend, self, missing_required
)
try:
return _memory_tool_impl(
backend, self, missing_required
)
except InvalidPathError as e:
return f"Error: {e.msg}"
return MemoryTool.as_tool("memory")
4 changes: 3 additions & 1 deletion graphcore/tools/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@

ST = TypeVar("ST")

T_RES = TypeVar("T_RES", bound=str | Command)
type BareResult = str | dict

T_RES = TypeVar("T_RES", bound=BareResult | list[BareResult] | Command)

class WithInjectedState(BaseModel, Generic[ST]):
state: Annotated[ST, InjectedState]
Expand Down
Loading
Loading