Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/codegen/sdk/core/expressions/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def unwrap(self, node: Expression | None = None) -> None:

Unwraps the content of a node by removing its wrapping syntax and merging its content with its parent node.
Specifically handles dictionary unwrapping, maintaining proper indentation and formatting.
Supports multiple spread elements and maintains their order.

Args:
node (Expression | None): The node to unwrap. If None, uses the instance's value node.
Expand All @@ -40,7 +41,7 @@ def unwrap(self, node: Expression | None = None) -> None:
"""
from codegen.sdk.core.symbol_groups.dict import Dict

node = node or self._value_node
node = node or self._value_node.resolved_value
if isinstance(node, Dict) and isinstance(self.parent, Dict):
if self.start_point[0] != self.parent.start_point[0]:
self.remove(delete_formatting=False)
Expand All @@ -54,10 +55,13 @@ def unwrap(self, node: Expression | None = None) -> None:
else:
# Delete the remaining characters on this line
self.remove_byte_range(self.end_byte, next_sibling.start_byte - next_sibling.start_point[1])

else:
self.remove()
for k, v in node.items():
self.parent[k] = v.source.strip()
if node.unpack:
self.parent._underlying.append(self.node.unpack.source)

# Add all items from the unwrapped dictionary
for child in node._underlying:
if isinstance(child, Unpack):
self.parent._underlying.append(child)
self.parent.unpacks.append(child)
else: # Regular key-value pair
self.parent._underlying.append(child)
257 changes: 241 additions & 16 deletions src/codegen/sdk/core/symbol_groups/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,23 @@ class Dict(Expression[Parent], Builtin, MutableMapping[str, TExpression], Generi
"""

_underlying: Collection[Pair[TExpression, Self] | Unpack[Self], Parent]
unpack: Unpack[Self] | None = None
unpacks: list[Unpack[Self]] = [] # Store all spread elements

def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent, delimiter: str = ",", pair_type: type[Pair] = Pair) -> None:
# TODO: handle spread_element
super().__init__(ts_node, file_node_id, ctx, parent)
children = [pair_type(child, file_node_id, ctx, self) for child in ts_node.named_children if child.type not in (None, "comment", "spread_element", "dictionary_splat") and not child.is_error]
if unpack := self.child_by_field_types({"spread_element", "dictionary_splat"}):
children.append(unpack)
self.unpack = unpack
children = []
self.unpacks = [] # Store all spread elements

for child in ts_node.named_children:
if child.type in (None, "comment") or child.is_error:
continue
if child.type in ("spread_element", "dictionary_splat"):
unpack = Unpack(child, file_node_id, ctx, self)
children.append(unpack)
self.unpacks.append(unpack) # Keep track of all spread elements
else:
children.append(pair_type(child, file_node_id, ctx, self))

if len(children) > 1:
first_child = children[0].ts_node.end_byte - ts_node.start_byte
second_child = children[1].ts_node.start_byte - ts_node.start_byte
Expand All @@ -107,26 +115,46 @@ def __bool__(self) -> bool:
def __len__(self) -> int:
return len(list(elem for elem in self._underlying if isinstance(elem, Pair)))

def _get_unpacked_items(self) -> Iterator[tuple[str, TExpression]]:
"""Get key-value pairs from all unpacked dictionaries."""
for unpack in self.unpacks:
if isinstance(unpack.value, Dict):
yield from unpack.value.items()

def __iter__(self) -> Iterator[str]:
# First yield keys from regular pairs
for pair in self._underlying:
if isinstance(pair, Pair):
if pair.key is not None:
if isinstance(pair.key, String):
yield pair.key.content
else:
yield pair.key.source
# Then yield keys from unpacked dictionaries
for unpack in self.unpacks:
for key, _ in self._get_unpacked_items():
yield key

def __getitem__(self, __key) -> TExpression:
for pair in self._underlying:
if isinstance(pair, Pair):
if isinstance(pair.key, String):
if pair.key.content == str(__key):
return pair.value
elif pair.key is not None:
if pair.key.source == str(__key):
return pair.value
msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}"
raise KeyError(msg)
# First try regular pairs
try:
for pair in self._underlying:
if isinstance(pair, Pair):
if isinstance(pair.key, String):
if pair.key.content == str(__key):
return pair.value
elif pair.key is not None:
if pair.key.source == str(__key):
return pair.value
# Then try unpacked dictionaries
for unpack in self.unpacks:
for key, value in self._get_unpacked_items():
if key == str(__key):
return value
raise KeyError
except KeyError:
msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}"
raise KeyError(msg)

def __setitem__(self, __key, __value: TExpression) -> None:
new_value = __value.source if isinstance(__value, Editable) else str(__value)
Expand Down Expand Up @@ -178,3 +206,200 @@ def descendant_symbols(self) -> list["Importable"]:
@property
def __class__(self):
return dict

def __repr__(self) -> str:
"""Return a string representation of the dictionary including spread elements."""
items = []

# Add spread elements in their original position
for child in self._underlying:
if isinstance(child, Unpack):
items.append(child.source)
else: # Regular key-value pair
if child.key is not None:
if isinstance(child.key, String):
key = child.key.content
else:
key = child.key.source
items.append(f"{key}: {child.value.source}")

return "{" + ", ".join(items) + "}"

def __str__(self) -> str:
return self.__repr__()

def _get_all_unpacks_and_keys(self, seen_unpacks: set, seen_keys: set) -> None:
"""Recursively get all unpacks and their keys from this dictionary and its dependencies.

Args:
seen_unpacks: Set to store all found unpacks
seen_keys: Set to store all keys from unpacked dictionaries
"""
for child in self._underlying:
if isinstance(child, Unpack):
# Get the name being unpacked (e.g., "base1" from "**base1")
unpack_name = child.source.strip("*")
seen_unpacks.add(unpack_name)

unpacked_dict = self.file.get_symbol(unpack_name).value
if isinstance(unpacked_dict, Dict):
# Add all keys from this dict
for unpacked_child in unpacked_dict._underlying:
if not isinstance(unpacked_child, Unpack) and unpacked_child.key is not None:
seen_keys.add(unpacked_child.key.source)
# Recursively check its unpacks
unpacked_dict._get_all_unpacks_and_keys(seen_unpacks, seen_keys)

def merge(self, *others: "Dict[TExpression, Parent] | str") -> None:
"""Merge multiple dictionaries into a new dictionary

Preserves spread operators and function calls in their original form.
Later dictionaries take precedence over earlier ones for duplicate keys.

Args:
*others: Other Dict objects or dictionary strings.
The strings can be either Python dicts (e.g. "{'x': 1}")
or TypeScript objects (e.g. "{x: 1}")

Raises:
ValueError: If attempting to merge dictionaries with duplicate keys or unpacks

Returns:
None
"""
# Track seen keys and unpacks to prevent duplicates
seen_keys = set()
seen_unpacks = set()

# Get all unpacks and their keys from its dependencies
self._get_all_unpacks_and_keys(seen_unpacks, seen_keys)

# Keep track of all items in order
merged_items = []

# First add all items from this dictionary
for child in self._underlying:
if isinstance(child, Unpack):
unpack_source = child.source
merged_items.append(unpack_source)
else: # Regular key-value pair
if child.key is not None:
key = child.key.source
if key in seen_keys:
msg = f"Duplicate key found: {key}"
raise ValueError(msg)
seen_keys.add(key)
merged_items.append(f"{key}: {child.value.source}")

# Add items from other dictionaries
for other in others:
if isinstance(other, Dict):
# Handle Dict objects from our SDK
for child in other._underlying:
if isinstance(child, Unpack):
unpack_source = child.source
# Get the name being unpacked (e.g., "base1" from "**base1")
unpack_name = unpack_source.strip("*")
if unpack_name in seen_unpacks:
msg = f"Duplicate unpack found: {unpack_source}"
raise ValueError(msg)
seen_unpacks.add(unpack_name)
merged_items.append(unpack_source)
else: # Regular key-value pair
if child.key is not None:
key = child.key.source
if key in seen_keys:
msg = f"Duplicate key found: {key}"
raise ValueError(msg)
seen_keys.add(key)
merged_items.append(f"{key}: {child.value.source}")
elif isinstance(other, str):
# Handle dictionary string
# Strip curly braces and whitespace
content = other.strip().strip("{}").strip()
if not content: # Skip empty dicts
continue

# Parse the content to check for duplicates
parts = content.split(",")
for part in parts:
part = part.strip()
if part.startswith("**"):
# Get the name being unpacked (e.g., "base1" from "**base1")
unpack_name = part.strip("*").strip() # Fix unpack name extraction
if unpack_name in seen_unpacks:
msg = f"Duplicate unpack found: {part}"
raise ValueError(msg)

unpacked_dict = self.file.get_symbol(unpack_name).value
if isinstance(unpacked_dict, Dict):
# Add all keys from this dict
for unpacked_child in unpacked_dict._underlying:
if not isinstance(unpacked_child, Unpack) and unpacked_child.key is not None:
if unpacked_child.key.source in seen_keys:
msg = f"Duplicate key found: {unpacked_child.key.source}"
raise ValueError(msg)

seen_unpacks.add(unpack_name)
merged_items.append(part)
else:
# It's a key-value pair
key = part.split(":")[0].strip()
if key in seen_keys:
msg = f"Duplicate key found: {key}"
raise ValueError(msg)
seen_keys.add(key)
merged_items.append(part)
else:
msg = f"Cannot merge with object of type {type(other)}"
raise TypeError(msg)

# Create merged source
merged_source = "{" + ", ".join(merged_items) + "}"

# Replace this dict's source with merged source
self.edit(merged_source)

def add(self, typescript_dict: str) -> None:
"""Add a TypeScript dictionary string to this dictionary

Args:
typescript_dict: A TypeScript dictionary string e.g. "{a: 1, b: 2}"

Returns:
None
"""
# Get current items
merged_items = []

# Add all items from this dictionary first
for child in self._underlying:
if isinstance(child, Unpack):
merged_items.append(child.source)
elif child.key is not None:
merged_items.append(f"{child.key.source}: {child.value.source}")

# Add the TypeScript dictionary content
typescript_dict = typescript_dict.strip().strip("{}").strip()
if typescript_dict: # Only add if not empty
merged_items.append(typescript_dict)

# Create merged source
merged_source = "{" + ", ".join(merged_items) + "}"

# Replace this dict's source with merged source
self.edit(merged_source)

def unwrap(self) -> None:
"""Unwrap all spread elements in this dictionary.

This will replace all spread elements with their actual key-value pairs.
For example:
{'a': 1, ...dict2, 'b': 2} -> {'a': 1, 'c': 3, 'd': 4, 'b': 2}

Returns:
None
"""
# Process all spread elements
for unpack in list(self.unpacks): # Make a copy since we'll modify during iteration
unpack.unwrap()
Loading
Loading