From 3aadb35df57c1f041219835a49db52be6af7dc7a Mon Sep 17 00:00:00 2001 From: tawsifkamal Date: Tue, 4 Mar 2025 16:34:27 -0800 Subject: [PATCH 1/2] done --- src/codegen/sdk/core/expressions/unpack.py | 16 +- src/codegen/sdk/core/symbol_groups/dict.py | 280 ++++++++++++++++-- .../sdk/typescript/symbol_groups/dict.py | 63 +++- .../sdk/python/expressions/test_dict.py | 260 ++++++++++++++++ .../sdk/typescript/expressions/test_dict.py | 190 ++++++++++++ 5 files changed, 782 insertions(+), 27 deletions(-) diff --git a/src/codegen/sdk/core/expressions/unpack.py b/src/codegen/sdk/core/expressions/unpack.py index 10dd7be52..b6c6b70af 100644 --- a/src/codegen/sdk/core/expressions/unpack.py +++ b/src/codegen/sdk/core/expressions/unpack.py @@ -31,6 +31,7 @@ def unwrap(self, node: Expression | None = None) -> None: Unwraps the content of a node by removing its wrapping syntax and merging its content with its parent node. Specifically handles dictionary unwrapping, maintaining proper indentation and formatting. + Supports multiple spread elements and maintains their order. Args: node (Expression | None): The node to unwrap. If None, uses the instance's value node. @@ -40,7 +41,7 @@ def unwrap(self, node: Expression | None = None) -> None: """ from codegen.sdk.core.symbol_groups.dict import Dict - node = node or self._value_node + node = node or self._value_node.resolved_value if isinstance(node, Dict) and isinstance(self.parent, Dict): if self.start_point[0] != self.parent.start_point[0]: self.remove(delete_formatting=False) @@ -54,10 +55,13 @@ def unwrap(self, node: Expression | None = None) -> None: else: # Delete the remaining characters on this line self.remove_byte_range(self.end_byte, next_sibling.start_byte - next_sibling.start_point[1]) - else: self.remove() - for k, v in node.items(): - self.parent[k] = v.source.strip() - if node.unpack: - self.parent._underlying.append(self.node.unpack.source) + + # Add all items from the unwrapped dictionary + for child in node._underlying: + if isinstance(child, Unpack): + self.parent._underlying.append(child) + self.parent.unpacks.append(child) + else: # Regular key-value pair + self.parent._underlying.append(child) diff --git a/src/codegen/sdk/core/symbol_groups/dict.py b/src/codegen/sdk/core/symbol_groups/dict.py index 20bc3b984..1009086a5 100644 --- a/src/codegen/sdk/core/symbol_groups/dict.py +++ b/src/codegen/sdk/core/symbol_groups/dict.py @@ -1,5 +1,5 @@ from collections.abc import Iterator, MutableMapping -from typing import TYPE_CHECKING, Generic, Self, TypeVar +from typing import TYPE_CHECKING, Generic, Self, TypeVar, overload from tree_sitter import Node as TSNode @@ -86,15 +86,23 @@ class Dict(Expression[Parent], Builtin, MutableMapping[str, TExpression], Generi """ _underlying: Collection[Pair[TExpression, Self] | Unpack[Self], Parent] - unpack: Unpack[Self] | None = None + unpacks: list[Unpack[Self]] = [] # Store all spread elements def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent, delimiter: str = ",", pair_type: type[Pair] = Pair) -> None: - # TODO: handle spread_element super().__init__(ts_node, file_node_id, ctx, parent) - children = [pair_type(child, file_node_id, ctx, self) for child in ts_node.named_children if child.type not in (None, "comment", "spread_element", "dictionary_splat") and not child.is_error] - if unpack := self.child_by_field_types({"spread_element", "dictionary_splat"}): - children.append(unpack) - self.unpack = unpack + children = [] + self.unpacks = [] # Store all spread elements + + for child in ts_node.named_children: + if child.type in (None, "comment") or child.is_error: + continue + if child.type in ("spread_element", "dictionary_splat"): + unpack = Unpack(child, file_node_id, ctx, self) + children.append(unpack) + self.unpacks.append(unpack) # Keep track of all spread elements + else: + children.append(pair_type(child, file_node_id, ctx, self)) + if len(children) > 1: first_child = children[0].ts_node.end_byte - ts_node.start_byte second_child = children[1].ts_node.start_byte - ts_node.start_byte @@ -107,7 +115,14 @@ def __bool__(self) -> bool: def __len__(self) -> int: return len(list(elem for elem in self._underlying if isinstance(elem, Pair))) + def _get_unpacked_items(self) -> Iterator[tuple[str, TExpression]]: + """Get key-value pairs from all unpacked dictionaries.""" + for unpack in self.unpacks: + if isinstance(unpack.value, Dict): + yield from unpack.value.items() + def __iter__(self) -> Iterator[str]: + # First yield keys from regular pairs for pair in self._underlying: if isinstance(pair, Pair): if pair.key is not None: @@ -115,18 +130,31 @@ def __iter__(self) -> Iterator[str]: yield pair.key.content else: yield pair.key.source + # Then yield keys from unpacked dictionaries + for unpack in self.unpacks: + for key, _ in self._get_unpacked_items(): + yield key def __getitem__(self, __key) -> TExpression: - for pair in self._underlying: - if isinstance(pair, Pair): - if isinstance(pair.key, String): - if pair.key.content == str(__key): - return pair.value - elif pair.key is not None: - if pair.key.source == str(__key): - return pair.value - msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}" - raise KeyError(msg) + # First try regular pairs + try: + for pair in self._underlying: + if isinstance(pair, Pair): + if isinstance(pair.key, String): + if pair.key.content == str(__key): + return pair.value + elif pair.key is not None: + if pair.key.source == str(__key): + return pair.value + # Then try unpacked dictionaries + for unpack in self.unpacks: + for key, value in self._get_unpacked_items(): + if key == str(__key): + return value + raise KeyError + except KeyError: + msg = f"Key {__key} not found in {list(self.keys())} {self._underlying!r}" + raise KeyError(msg) def __setitem__(self, __key, __value: TExpression) -> None: new_value = __value.source if isinstance(__value, Editable) else str(__value) @@ -178,3 +206,221 @@ def descendant_symbols(self) -> list["Importable"]: @property def __class__(self): return dict + + def __repr__(self) -> str: + """Return a string representation of the dictionary including spread elements.""" + items = [] + + # Add spread elements in their original position + for child in self._underlying: + if isinstance(child, Unpack): + items.append(child.source) + else: # Regular key-value pair + if child.key is not None: + if isinstance(child.key, String): + key = child.key.content + else: + key = child.key.source + items.append(f"{key}: {child.value.source}") + + return "{" + ", ".join(items) + "}" + + def __str__(self) -> str: + return self.__repr__() + + def _get_all_unpacks_and_keys(self, seen_unpacks: set, seen_keys: set) -> None: + """Recursively get all unpacks and their keys from this dictionary and its dependencies. + + Args: + seen_unpacks: Set to store all found unpacks + seen_keys: Set to store all keys from unpacked dictionaries + """ + for child in self._underlying: + if isinstance(child, Unpack): + # Get the name being unpacked (e.g., "base1" from "**base1") + unpack_name = child.source.strip("*") + seen_unpacks.add(unpack_name) + + unpacked_dict = self.file.get_symbol(unpack_name).value + if isinstance(unpacked_dict, Dict): + # Add all keys from this dict + for unpacked_child in unpacked_dict._underlying: + if not isinstance(unpacked_child, Unpack) and unpacked_child.key is not None: + seen_keys.add(unpacked_child.key.source) + # Recursively check its unpacks + unpacked_dict._get_all_unpacks_and_keys(seen_unpacks, seen_keys) + + def _get_unpack_name(self, unpack_source: str) -> str: + """Get the name being unpacked from the source. + + Args: + unpack_source: Source code of the unpack (e.g., "**base1" or "...base1") + + Returns: + Name being unpacked (e.g., "base1") + """ + if unpack_source.startswith("**"): + return unpack_source.strip("*") + elif unpack_source.startswith("..."): + return unpack_source[3:] # Remove the three dots + return unpack_source + + @overload + def merge(self, *others: "Dict[TExpression, Parent]") -> None: ... + + @overload + def merge(self, dict_str: str) -> None: ... + + def merge(self, *others: "Dict[TExpression, Parent] | str") -> None: + """Merge multiple dictionaries into a new dictionary + + Preserves spread operators and function calls in their original form. + Later dictionaries take precedence over earlier ones for duplicate keys. + + Args: + *others: Other Dict objects or a dictionary string to merge. + The string can be either a Python dict (e.g. "{'x': 1}") + or a TypeScript object (e.g. "{x: 1}") + + Raises: + ValueError: If attempting to merge dictionaries with duplicate keys or unpacks + + Returns: + None + """ + # Track seen keys and unpacks to prevent duplicates + seen_keys = set() + seen_unpacks = set() + + # Get all unpacks and their keys from the current dictionary and its dependencies + self._get_all_unpacks_and_keys(seen_unpacks, seen_keys) + + # Keep track of all items in order + merged_items = [] + + # First add all items from this dictionary + for child in self._underlying: + if isinstance(child, Unpack): + unpack_source = child.source + merged_items.append(unpack_source) + else: # Regular key-value pair + if child.key is not None: + key = child.key.source + if key in seen_keys: + msg = f"Duplicate key found: {key}" + raise ValueError(msg) + seen_keys.add(key) + merged_items.append(f"{key}: {child.value.source}") + + # Add items from other dictionaries + for other in others: + if isinstance(other, Dict): + # Handle Dict objects from our SDK + for child in other._underlying: + if isinstance(child, Unpack): + unpack_source = child.source + # Get the name being unpacked (e.g., "base1" from "**base1") + unpack_name = self._get_unpack_name(unpack_source) + if unpack_name in seen_unpacks: + msg = f"Duplicate unpack found: {unpack_source}" + raise ValueError(msg) + seen_unpacks.add(unpack_name) + merged_items.append(unpack_source) + else: # Regular key-value pair + if child.key is not None: + key = child.key.source + if key in seen_keys: + msg = f"Duplicate key found: {key}" + raise ValueError(msg) + seen_keys.add(key) + merged_items.append(f"{key}: {child.value.source}") + elif isinstance(other, str): + # Handle dictionary strings + # Strip curly braces and whitespace + content = other.strip().strip("{}").strip() + if not content: # Skip empty dicts + continue + + # Parse the content to check for duplicates + parts = content.split(",") + for part in parts: + part = part.strip() + if part.startswith("**"): + # Get the name being unpacked (e.g., "base1" from "**base1") + unpack_name = self._get_unpack_name(part) + if unpack_name in seen_unpacks: + msg = f"Duplicate unpack found: {part}" + raise ValueError(msg) + + unpacked_dict = self.file.get_symbol(unpack_name).value + if isinstance(unpacked_dict, Dict): + # Add all keys from this dict + for unpacked_child in unpacked_dict._underlying: + if not isinstance(unpacked_child, Unpack) and unpacked_child.key is not None: + if unpacked_child.key.source in seen_keys: + msg = f"Duplicate key found: {unpacked_child.key.source}" + raise ValueError(msg) + + seen_unpacks.add(unpack_name) + merged_items.append(part) + else: + # It's a key-value pair + key = part.split(":")[0].strip() + if key in seen_keys: + msg = f"Duplicate key found: {key}" + raise ValueError(msg) + seen_keys.add(key) + merged_items.append(part) + else: + msg = f"Cannot merge with object of type {type(other)}" + raise TypeError(msg) + + # Create merged source + merged_source = "{" + ", ".join(merged_items) + "}" + + # Replace this dict's source with merged source + self.edit(merged_source) + + def add(self, typescript_dict: str) -> None: + """Add a TypeScript dictionary string to this dictionary + + Args: + typescript_dict: A TypeScript dictionary string e.g. "{a: 1, b: 2}" + + Returns: + None + """ + # Get current items + merged_items = [] + + # Add all items from this dictionary first + for child in self._underlying: + if isinstance(child, Unpack): + merged_items.append(child.source) + elif child.key is not None: + merged_items.append(f"{child.key.source}: {child.value.source}") + + # Add the TypeScript dictionary content + typescript_dict = typescript_dict.strip().strip("{}").strip() + if typescript_dict: # Only add if not empty + merged_items.append(typescript_dict) + + # Create merged source + merged_source = "{" + ", ".join(merged_items) + "}" + + # Replace this dict's source with merged source + self.edit(merged_source) + + def unwrap(self) -> None: + """Unwrap all spread elements in this dictionary. + + This will replace all spread elements with their actual key-value pairs. + For example: + {'a': 1, ...dict2, 'b': 2} -> {'a': 1, 'c': 3, 'd': 4, 'b': 2} + + Returns: + None + """ + # Process all spread elements + for unpack in list(self.unpacks): # Make a copy since we'll modify during iteration + unpack.unwrap() diff --git a/src/codegen/sdk/typescript/symbol_groups/dict.py b/src/codegen/sdk/typescript/symbol_groups/dict.py index 09fb9ad4d..a5a8323f9 100644 --- a/src/codegen/sdk/typescript/symbol_groups/dict.py +++ b/src/codegen/sdk/typescript/symbol_groups/dict.py @@ -6,8 +6,8 @@ from codegen.sdk.core.autocommit import writer from codegen.sdk.core.expressions import Expression from codegen.sdk.core.expressions.string import String +from codegen.sdk.core.expressions.unpack import Unpack from codegen.sdk.core.interfaces.editable import Editable -from codegen.sdk.core.interfaces.has_attribute import HasAttribute from codegen.sdk.core.node_id_factory import NodeId from codegen.sdk.core.symbol_groups.dict import Dict, Pair from codegen.sdk.extensions.autocommit import reader @@ -71,13 +71,13 @@ def reduce_condition(self, bool_condition: bool, node: Editable | None = None) - @apidoc -class TSDict(Dict, HasAttribute): +class TSDict(Dict[Expression, Parent]): """A typescript dict object. You can use standard operations to operate on this dict (IE len, del, set, get, etc)""" def __init__(self, ts_node: TSNode, file_node_id: NodeId, ctx: "CodebaseContext", parent: Parent, delimiter: str = ",", pair_type: type[Pair] = TSPair) -> None: super().__init__(ts_node, file_node_id, ctx, parent, delimiter=delimiter, pair_type=pair_type) - def __getitem__(self, __key: str) -> TExpression: + def __getitem__(self, __key: str) -> Expression: for pair in self._underlying: pair_match = None @@ -116,7 +116,7 @@ def __setitem__(self, __key: str, __value: TExpression) -> None: if __key == new_value: pair_match.edit(f"{__key}") else: - pair.value.edit(f"{new_value}") + pair_match.value.edit(f"{new_value}") # CASE: {a} else: if __key == new_value: @@ -142,3 +142,58 @@ def __setitem__(self, __key: str, __value: TExpression) -> None: @override def resolve_attribute(self, name: str) -> "Expression | None": return self.get(name, None) + + def merge(self, *others: "Dict[Expression, Parent] | str") -> None: + """Merge multiple dictionaries into a new dictionary. + + Preserves spread operators and function calls in their original form. + Later dictionaries take precedence over earlier ones for duplicate keys. + In TypeScript, duplicate keys and spreads are allowed - later ones override earlier ones. + + Args: + *others: Other Dict objects or a dictionary string to merge. + The string can be either a Python dict (e.g. "{'x': 1}") + or a TypeScript object (e.g. "{x: 1}") + + Returns: + None + """ + # Keep track of all items in order + merged_items = [] + + # First add all items from this dictionary + for child in self._underlying: + if isinstance(child, Unpack): + merged_items.append(child.source) + elif child.key is not None: + merged_items.append(f"{child.key.source}: {child.value.source}") + + # Then add items from other dictionaries + for other in others: + if isinstance(other, Dict): + # Handle Dict objects from our SDK + for child in other._underlying: + if isinstance(child, Unpack): + merged_items.append(child.source) + elif child.key is not None: + merged_items.append(f"{child.key.source}: {child.value.source}") + elif isinstance(other, str): + # Handle dictionary strings + content = other.strip().strip("{}").strip() + if not content: # Skip empty dicts + continue + + # Parse the content + parts = content.split(",") + for part in parts: + part = part.strip() + merged_items.append(part) + else: + msg = f"Cannot merge with object of type {type(other)}" + raise TypeError(msg) + + # Create merged source + merged_source = "{" + ", ".join(merged_items) + "}" + + # Replace this dict's source with merged source + self.edit(merged_source) diff --git a/tests/unit/codegen/sdk/python/expressions/test_dict.py b/tests/unit/codegen/sdk/python/expressions/test_dict.py index 5749ac1c1..eaafcbbd0 100644 --- a/tests/unit/codegen/sdk/python/expressions/test_dict.py +++ b/tests/unit/codegen/sdk/python/expressions/test_dict.py @@ -325,3 +325,263 @@ def test_dict_clear(tmpdir) -> None: symbol = {} """ ) + + +def test_dict_merge(tmpdir) -> None: + """Test merging dictionaries with and without spread operators.""" + file = "test.py" + # language=python + content = """ +dict1 = {'a': 1, 'b': 2} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + dict1 = file.get_symbol("dict1").value + dict1.merge("{'x': 3, 'y': 4}") + codebase.commit() + assert ( + file.content + == """ +dict1 = {'a': 1, 'b': 2, 'x': 3, 'y': 4} +""" + ) + + +def test_dict_unwrap(tmpdir) -> None: + """Test unwrapping spread operators in dictionaries.""" + file = "test.py" + # language=python + content = """ +base = {'x': 1, 'y': 2} +dict1 = {'a': 1, **base, 'b': 2} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + dict1 = file.get_symbol("dict1").value + assert len(dict1.unpacks) == 1 + dict1.unwrap() + codebase.commit() + assert ( + file.content + == """ +base = {'x': 1, 'y': 2} +dict1 = {'a': 1, 'b': 2, 'x': 1, 'y': 2} +""" + ) + + +def test_dict_merge_variations(tmpdir) -> None: + """Test various merge scenarios with different dictionary formats.""" + file = "test.py" + # language=python + content = """ +simple = {'a': 1, 'b': 2} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + simple = file.get_symbol("simple").value + simple.merge("{a: 1, b: 2}") + codebase.commit() + assert ( + file.content + == """ +simple = {'a': 1, 'b': 2, a: 1, b: 2} +""" + ) + + +def test_dict_unwrap_complex(tmpdir) -> None: + """Test unwrapping in various complex scenarios.""" + file = "test.py" + # language=python + content = """ +base1 = {'x': 1, 'y': 2} +dict1 = {**base1, 'z': 3} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + dict1 = file.get_symbol("dict1").value + dict1.unwrap() + codebase.commit() + assert ( + file.content + == """ +base1 = {'x': 1, 'y': 2} +dict1 = {'z': 3, 'x': 1, 'y': 2} +""" + ) + + +def test_dict_merge_duplicate_unpack(tmpdir) -> None: + """Test that merging with duplicate unpacks raises ValueError.""" + file = "test.py" + # language=python + content = """ +base = {'x': 1, 'y': 2} +simple = {'m': 0, **base} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + simple = file.get_symbol("simple").value + + # Should raise ValueError because **base already exists in simple + with pytest.raises(ValueError, match="Duplicate unpack found: \\*\\*base"): + simple.merge("{'p': 6, **base}") + codebase.commit() + + +def test_dict_merge_duplicate_keys(tmpdir) -> None: + """Test that merging with duplicate keys raises ValueError.""" + file = "test.py" + # language=python + content = """ +dict1 = {'a': 1, 'b': 2} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + dict1 = file.get_symbol("dict1").value + + # Should raise ValueError because 'a' already exists + with pytest.raises(ValueError, match="Duplicate key found: 'a'"): + dict1.merge("{'a': 3, 'c': 4}") + codebase.commit() + + +def test_dict_merge_complex(tmpdir) -> None: + """Test complex merge scenarios with multiple dictionaries and spreads.""" + file = "test.py" + # language=python + content = """ +base1 = {'x': 1, 'y': 2} +base2 = {'a': 3, **base1} +base3 = {'c': 4, 'd': 5} +result = {'m': 0, **base2, 'n': 5} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + + # Should work - no duplicate keys or unpacks + result.merge("{'p': 6, 'q': 7}") + codebase.commit() + assert ( + file.content + == """ +base1 = {'x': 1, 'y': 2} +base2 = {'a': 3, **base1} +base3 = {'c': 4, 'd': 5} +result = {'m': 0, **base2, 'n': 5, 'p': 6, 'q': 7} +""" + ) + + # Should fail - trying to merge base1 when it's already unpacked via base2 + with pytest.raises(ValueError, match="Duplicate unpack found: \\*\\*base1"): + result.merge("{'r': 8, **base1}") + + # Should fail - trying to add duplicate key 'a' + with pytest.raises(ValueError, match="Duplicate key found: 'a'"): + result.merge("{'a': 9}") + + +def test_dict_merge_multiple_unpacks(tmpdir) -> None: + """Test merging multiple dictionaries with unpacks.""" + file = "test.py" + # language=python + content = """ +base1 = {'x': 1, 'y': 2} +base2 = {'a': 3, **base1} +base3 = {'c': 4, 'd': 5} +result = {'m': 0, **base2, 'n': 5, 'p': 6, 'q': 7} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + + # Should work - new unique keys and unpacks + result.merge("{'s': 10, **base3}") + codebase.commit() + assert ( + file.content + == """ +base1 = {'x': 1, 'y': 2} +base2 = {'a': 3, **base1} +base3 = {'c': 4, 'd': 5} +result = {'m': 0, **base2, 'n': 5, 'p': 6, 'q': 7, 's': 10, **base3} +""" + ) + + +def test_dict_merge_objects(tmpdir) -> None: + """Test merging Dict objects directly.""" + file = "test.py" + # language=python + content = """ +dict1 = {'x': 1, 'y': 2} +dict2 = {'a': 3, **dict1} +dict3 = {'b': 4, 'c': 5} +dict4 = {'d': 6, **dict3} +result = {'m': 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict2 = file.get_symbol("dict2").value + dict3 = file.get_symbol("dict3").value + dict4 = file.get_symbol("dict4").value + + # Should work - merging multiple Dict objects + result.merge(dict2, dict3) + codebase.commit() + assert ( + file.content + == """ +dict1 = {'x': 1, 'y': 2} +dict2 = {'a': 3, **dict1} +dict3 = {'b': 4, 'c': 5} +dict4 = {'d': 6, **dict3} +result = {'m': 0, 'a': 3, **dict1, 'b': 4, 'c': 5} +""" + ) + + +def test_dict_unpack_tracking(tmpdir) -> None: + """Test tracking of unpacks in Python dictionaries.""" + file = "test.py" + content = """ +base1 = {'x': 1, 'y': 2} +base2 = {'z': 3, **base1} +dict1 = {'a': 1, **base2, 'b': 2} +dict2 = {'c': 3, **base1, **base2, 'd': 4} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + + # Test simple unpack + base2 = file.get_symbol("base2").value + assert len(base2.unpacks) == 1 + assert base2.unpacks[0].source == "**base1" + + # Test single unpack with surrounding keys + dict1 = file.get_symbol("dict1").value + assert len(dict1.unpacks) == 1 + assert dict1.unpacks[0].source == "**base2" + + # Test multiple unpacks + dict2 = file.get_symbol("dict2").value + assert len(dict2.unpacks) == 2 + assert dict2.unpacks[0].source == "**base1" + assert dict2.unpacks[1].source == "**base2" + + # Test that unpacks are preserved after merging + dict2.merge("{'e': 5}") + codebase.commit() + assert len(dict2.unpacks) == 2 + assert ( + file.content + == """ +base1 = {'x': 1, 'y': 2} +base2 = {'z': 3, **base1} +dict1 = {'a': 1, **base2, 'b': 2} +dict2 = {'c': 3, **base1, **base2, 'd': 4, 'e': 5} +""" + ) diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_dict.py b/tests/unit/codegen/sdk/typescript/expressions/test_dict.py index 91b97f1ab..d55acdec9 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/test_dict.py +++ b/tests/unit/codegen/sdk/typescript/expressions/test_dict.py @@ -650,3 +650,193 @@ def test_dict_usage_spread(tmpdir): let obj = {1: "a", a: foo()} """ ) + + +def test_dict_merge_basic(tmpdir) -> None: + """Test basic merge functionality in TypeScript dictionaries.""" + file = "test.ts" + # language=typescript + content = """ +const base1 = {x: 1, y: 2} +const base2 = {a: 3, ...base1} +const base3 = {c: 4, d: 5} +const result = {m: 0, ...base2, n: 5} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + + # Should work - adding new keys + result.merge("{p: 6, q: 7}") + codebase.commit() + assert ( + file.content + == """ +const base1 = {x: 1, y: 2} +const base2 = {a: 3, ...base1} +const base3 = {c: 4, d: 5} +const result = {m: 0, ...base2, n: 5, p: 6, q: 7} +""" + ) + + +def test_dict_merge_duplicate_keys(tmpdir) -> None: + """Test merging with duplicate keys in TypeScript dictionaries.""" + file = "test.ts" + # language=typescript + content = """ +const base1 = {x: 1, y: 2} +const base2 = {a: 3, ...base1} +const base3 = {c: 4, d: 5} +const result = {m: 0, ...base2, n: 5} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + + # Should work - duplicate keys are allowed in TypeScript + result.merge("{a: 9}") + codebase.commit() + assert ( + file.content + == """ +const base1 = {x: 1, y: 2} +const base2 = {a: 3, ...base1} +const base3 = {c: 4, d: 5} +const result = {m: 0, ...base2, n: 5, a: 9} +""" + ) + + +def test_dict_merge_multiple_unpacks(tmpdir) -> None: + """Test merging multiple TypeScript dictionaries with unpacks.""" + file = "test.ts" + # language=typescript + content = """ +const base1 = {x: 1, y: 2} +const base2 = {a: 3, ...base1} +const base3 = {c: 4, d: 5} +const result = {m: 0, ...base2, n: 5, p: 6, q: 7} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + + # Should work - new unique keys and unpacks + result.merge("{s: 10, ...base3}") + codebase.commit() + assert ( + file.content + == """ +const base1 = {x: 1, y: 2} +const base2 = {a: 3, ...base1} +const base3 = {c: 4, d: 5} +const result = {m: 0, ...base2, n: 5, p: 6, q: 7, s: 10, ...base3} +""" + ) + + +def test_dict_merge_objects(tmpdir) -> None: + """Test merging TypeScript Dict objects directly.""" + file = "test.ts" + # language=typescript + content = """ +const dict1 = {x: 1, y: 2} +const dict2 = {a: 3, ...dict1} +const dict3 = {b: 4, c: 5} +const dict4 = {d: 6, ...dict3} +const result = {m: 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict2 = file.get_symbol("dict2").value + dict3 = file.get_symbol("dict3").value + dict4 = file.get_symbol("dict4").value + + # Should work - merging multiple Dict objects + result.merge(dict2, dict3) + codebase.commit() + assert ( + file.content + == """ +const dict1 = {x: 1, y: 2} +const dict2 = {a: 3, ...dict1} +const dict3 = {b: 4, c: 5} +const dict4 = {d: 6, ...dict3} +const result = {m: 0, a: 3, ...dict1, b: 4, c: 5} +""" + ) + + +def test_dict_unwrap_basic(tmpdir) -> None: + """Test basic unwrapping of spread operators in TypeScript dictionaries.""" + file = "test.ts" + # language=typescript + content = """ +const base = {x: 1, y: 2} +const dict1 = {a: 1, ...base, b: 2} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + dict1 = file.get_symbol("dict1").value + assert len(dict1.unpacks) == 1 + dict1.unwrap() + codebase.commit() + assert ( + file.content + == """ +const base = {x: 1, y: 2} +const dict1 = {a: 1, b: 2, x: 1, y: 2} +""" + ) + + +def test_dict_unwrap_multiple_spreads(tmpdir) -> None: + """Test unwrapping multiple spread operators in TypeScript dictionaries.""" + file = "test.ts" + # language=typescript + content = """ +const base1 = {x: 1, y: 2} +const base2 = {z: 3, w: 4} +const dict1 = {a: 1, ...base1, b: 2, ...base2, c: 3} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + dict1 = file.get_symbol("dict1").value + assert len(dict1.unpacks) == 2 + dict1.unwrap() + codebase.commit() + assert ( + file.content + == """ +const base1 = {x: 1, y: 2} +const base2 = {z: 3, w: 4} +const dict1 = {a: 1, b: 2, c: 3, x: 1, y: 2, z: 3, w: 4} +""" + ) + + +def test_dict_unwrap_nested_spreads(tmpdir) -> None: + """Test unwrapping nested spread operators in TypeScript dictionaries.""" + file = "test.ts" + # language=typescript + content = """ +const base1 = {x: 1, y: 2} +const base2 = {z: 3, ...base1} +const dict1 = {a: 1, ...base2, b: 2} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + dict1 = file.get_symbol("dict1").value + assert len(dict1.unpacks) == 1 + dict1.unwrap() + codebase.commit() + assert ( + file.content + == """ +const base1 = {x: 1, y: 2} +const base2 = {z: 3, ...base1} +const dict1 = {a: 1, b: 2, z: 3, ...base1} +""" + ) From bd7c92390a919e969b09f489ce82c17102259d90 Mon Sep 17 00:00:00 2001 From: tawsifkamal Date: Tue, 4 Mar 2025 17:08:11 -0800 Subject: [PATCH 2/2] done --- src/codegen/sdk/core/symbol_groups/dict.py | 37 ++---- .../sdk/typescript/symbol_groups/dict.py | 8 +- .../sdk/python/expressions/test_dict.py | 79 ++++++++++++ .../sdk/typescript/expressions/test_dict.py | 116 ++++++++++++++++++ 4 files changed, 207 insertions(+), 33 deletions(-) diff --git a/src/codegen/sdk/core/symbol_groups/dict.py b/src/codegen/sdk/core/symbol_groups/dict.py index 1009086a5..8310f966e 100644 --- a/src/codegen/sdk/core/symbol_groups/dict.py +++ b/src/codegen/sdk/core/symbol_groups/dict.py @@ -1,5 +1,5 @@ from collections.abc import Iterator, MutableMapping -from typing import TYPE_CHECKING, Generic, Self, TypeVar, overload +from typing import TYPE_CHECKING, Generic, Self, TypeVar from tree_sitter import Node as TSNode @@ -250,27 +250,6 @@ def _get_all_unpacks_and_keys(self, seen_unpacks: set, seen_keys: set) -> None: # Recursively check its unpacks unpacked_dict._get_all_unpacks_and_keys(seen_unpacks, seen_keys) - def _get_unpack_name(self, unpack_source: str) -> str: - """Get the name being unpacked from the source. - - Args: - unpack_source: Source code of the unpack (e.g., "**base1" or "...base1") - - Returns: - Name being unpacked (e.g., "base1") - """ - if unpack_source.startswith("**"): - return unpack_source.strip("*") - elif unpack_source.startswith("..."): - return unpack_source[3:] # Remove the three dots - return unpack_source - - @overload - def merge(self, *others: "Dict[TExpression, Parent]") -> None: ... - - @overload - def merge(self, dict_str: str) -> None: ... - def merge(self, *others: "Dict[TExpression, Parent] | str") -> None: """Merge multiple dictionaries into a new dictionary @@ -278,9 +257,9 @@ def merge(self, *others: "Dict[TExpression, Parent] | str") -> None: Later dictionaries take precedence over earlier ones for duplicate keys. Args: - *others: Other Dict objects or a dictionary string to merge. - The string can be either a Python dict (e.g. "{'x': 1}") - or a TypeScript object (e.g. "{x: 1}") + *others: Other Dict objects or dictionary strings. + The strings can be either Python dicts (e.g. "{'x': 1}") + or TypeScript objects (e.g. "{x: 1}") Raises: ValueError: If attempting to merge dictionaries with duplicate keys or unpacks @@ -292,7 +271,7 @@ def merge(self, *others: "Dict[TExpression, Parent] | str") -> None: seen_keys = set() seen_unpacks = set() - # Get all unpacks and their keys from the current dictionary and its dependencies + # Get all unpacks and their keys from its dependencies self._get_all_unpacks_and_keys(seen_unpacks, seen_keys) # Keep track of all items in order @@ -320,7 +299,7 @@ def merge(self, *others: "Dict[TExpression, Parent] | str") -> None: if isinstance(child, Unpack): unpack_source = child.source # Get the name being unpacked (e.g., "base1" from "**base1") - unpack_name = self._get_unpack_name(unpack_source) + unpack_name = unpack_source.strip("*") if unpack_name in seen_unpacks: msg = f"Duplicate unpack found: {unpack_source}" raise ValueError(msg) @@ -335,7 +314,7 @@ def merge(self, *others: "Dict[TExpression, Parent] | str") -> None: seen_keys.add(key) merged_items.append(f"{key}: {child.value.source}") elif isinstance(other, str): - # Handle dictionary strings + # Handle dictionary string # Strip curly braces and whitespace content = other.strip().strip("{}").strip() if not content: # Skip empty dicts @@ -347,7 +326,7 @@ def merge(self, *others: "Dict[TExpression, Parent] | str") -> None: part = part.strip() if part.startswith("**"): # Get the name being unpacked (e.g., "base1" from "**base1") - unpack_name = self._get_unpack_name(part) + unpack_name = part.strip("*").strip() # Fix unpack name extraction if unpack_name in seen_unpacks: msg = f"Duplicate unpack found: {part}" raise ValueError(msg) diff --git a/src/codegen/sdk/typescript/symbol_groups/dict.py b/src/codegen/sdk/typescript/symbol_groups/dict.py index a5a8323f9..a7a1244fe 100644 --- a/src/codegen/sdk/typescript/symbol_groups/dict.py +++ b/src/codegen/sdk/typescript/symbol_groups/dict.py @@ -151,9 +151,9 @@ def merge(self, *others: "Dict[Expression, Parent] | str") -> None: In TypeScript, duplicate keys and spreads are allowed - later ones override earlier ones. Args: - *others: Other Dict objects or a dictionary string to merge. - The string can be either a Python dict (e.g. "{'x': 1}") - or a TypeScript object (e.g. "{x: 1}") + *others: Other Dict objects or dictionary strings. + The strings can be either Python dicts (e.g. "{'x': 1}") + or TypeScript objects (e.g. "{x: 1}") Returns: None @@ -178,7 +178,7 @@ def merge(self, *others: "Dict[Expression, Parent] | str") -> None: elif child.key is not None: merged_items.append(f"{child.key.source}: {child.value.source}") elif isinstance(other, str): - # Handle dictionary strings + # Handle dictionary string content = other.strip().strip("{}").strip() if not content: # Skip empty dicts continue diff --git a/tests/unit/codegen/sdk/python/expressions/test_dict.py b/tests/unit/codegen/sdk/python/expressions/test_dict.py index eaafcbbd0..9eea4c9ce 100644 --- a/tests/unit/codegen/sdk/python/expressions/test_dict.py +++ b/tests/unit/codegen/sdk/python/expressions/test_dict.py @@ -585,3 +585,82 @@ def test_dict_unpack_tracking(tmpdir) -> None: dict2 = {'c': 3, **base1, **base2, 'd': 4, 'e': 5} """ ) + + +def test_dict_merge_variadic(tmpdir) -> None: + """Test merging multiple dictionaries using variadic arguments.""" + file = "test.py" + content = """ +dict1 = {'a': 1} +dict2 = {'b': 2} +dict3 = {'c': 3} +result = {'m': 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict2 = file.get_symbol("dict2").value + dict3 = file.get_symbol("dict3").value + + # Test merging multiple Dict objects and strings + result.merge(dict2, dict3, "{'x': 4}", "{'y': 5}") + codebase.commit() + assert ( + file.content + == """ +dict1 = {'a': 1} +dict2 = {'b': 2} +dict3 = {'c': 3} +result = {'m': 0, 'b': 2, 'c': 3, 'x': 4, 'y': 5} +""" + ) + + +def test_dict_merge_variadic_with_unpacks(tmpdir) -> None: + """Test merging multiple dictionaries with unpacks using variadic arguments.""" + file = "test.py" + content = """ +base1 = {'x': 1} +base2 = {'y': 2} +dict1 = {'a': 1, **base1} +dict2 = {'b': 2, **base2} +result = {'m': 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict1 = file.get_symbol("dict1").value + dict2 = file.get_symbol("dict2").value + + # Test merging multiple Dict objects with unpacks + result.merge(dict1, dict2, "{'z': 3}") + codebase.commit() + assert ( + file.content + == """ +base1 = {'x': 1} +base2 = {'y': 2} +dict1 = {'a': 1, **base1} +dict2 = {'b': 2, **base2} +result = {'m': 0, 'a': 1, **base1, 'b': 2, **base2, 'z': 3} +""" + ) + + +def test_dict_merge_variadic_duplicate_keys(tmpdir) -> None: + """Test merging multiple dictionaries with duplicate keys using variadic arguments.""" + file = "test.py" + content = """ +dict1 = {'a': 1, 'b': 2} +dict2 = {'c': 3, 'd': 4} +result = {'m': 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.py": content}) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict1 = file.get_symbol("dict1").value + dict2 = file.get_symbol("dict2").value + + # Should fail - trying to add duplicate key 'a' + with pytest.raises(ValueError, match="Duplicate key found: 'a'"): + result.merge(dict1, dict2, "{'a': 5}") diff --git a/tests/unit/codegen/sdk/typescript/expressions/test_dict.py b/tests/unit/codegen/sdk/typescript/expressions/test_dict.py index d55acdec9..8c556c0a0 100644 --- a/tests/unit/codegen/sdk/typescript/expressions/test_dict.py +++ b/tests/unit/codegen/sdk/typescript/expressions/test_dict.py @@ -840,3 +840,119 @@ def test_dict_unwrap_nested_spreads(tmpdir) -> None: const dict1 = {a: 1, b: 2, z: 3, ...base1} """ ) + + +def test_dict_merge_variadic(tmpdir) -> None: + """Test merging multiple dictionaries using variadic arguments.""" + file = "test.ts" + content = """ +const dict1 = {a: 1} +const dict2 = {b: 2} +const dict3 = {c: 3} +const result = {m: 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict2 = file.get_symbol("dict2").value + dict3 = file.get_symbol("dict3").value + + # Test merging multiple Dict objects and strings + result.merge(dict2, dict3, "{x: 4}", "{y: 5}") + codebase.commit() + assert ( + file.content + == """ +const dict1 = {a: 1} +const dict2 = {b: 2} +const dict3 = {c: 3} +const result = {m: 0, b: 2, c: 3, x: 4, y: 5} +""" + ) + + +def test_dict_merge_variadic_with_spreads(tmpdir) -> None: + """Test merging multiple dictionaries with spreads using variadic arguments.""" + file = "test.ts" + content = """ +const base1 = {x: 1} +const base2 = {y: 2} +const dict1 = {a: 1, ...base1} +const dict2 = {b: 2, ...base2} +const result = {m: 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict1 = file.get_symbol("dict1").value + dict2 = file.get_symbol("dict2").value + + # Test merging multiple Dict objects with spreads + result.merge(dict1, dict2, "{z: 3}") + codebase.commit() + assert ( + file.content + == """ +const base1 = {x: 1} +const base2 = {y: 2} +const dict1 = {a: 1, ...base1} +const dict2 = {b: 2, ...base2} +const result = {m: 0, a: 1, ...base1, b: 2, ...base2, z: 3} +""" + ) + + +def test_dict_merge_variadic_duplicate_keys(tmpdir) -> None: + """Test merging multiple dictionaries with duplicate keys using variadic arguments.""" + file = "test.ts" + content = """ +const dict1 = {a: 1, b: 2} +const dict2 = {c: 3, d: 4} +const result = {m: 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict1 = file.get_symbol("dict1").value + dict2 = file.get_symbol("dict2").value + + # Test merging with duplicate keys - should work in TypeScript + result.merge(dict1, dict2, "{a: 5}") # Duplicate 'a' key is allowed + codebase.commit() + assert ( + file.content + == """ +const dict1 = {a: 1, b: 2} +const dict2 = {c: 3, d: 4} +const result = {m: 0, a: 1, b: 2, c: 3, d: 4, a: 5} +""" + ) + + +def test_dict_merge_variadic_duplicate_spreads(tmpdir) -> None: + """Test merging multiple dictionaries with duplicate spreads using variadic arguments.""" + file = "test.ts" + content = """ +const base1 = {x: 1} +const dict1 = {...base1} +const dict2 = {y: 2} +const result = {m: 0} +""" + with get_codebase_session(tmpdir=tmpdir, files={"test.ts": content}, programming_language=ProgrammingLanguage.TYPESCRIPT) as codebase: + file = codebase.get_file(file) + result = file.get_symbol("result").value + dict1 = file.get_symbol("dict1").value + dict2 = file.get_symbol("dict2").value + + # Test merging with duplicate spreads - should work in TypeScript + result.merge(dict1, dict2, "{...base1}") # Duplicate spread is allowed + codebase.commit() + assert ( + file.content + == """ +const base1 = {x: 1} +const dict1 = {...base1} +const dict2 = {y: 2} +const result = {m: 0, ...base1, y: 2, ...base1} +""" + )