diff --git a/changelog.md b/changelog.md
index 11de6a03..5b5e68a5 100644
--- a/changelog.md
+++ b/changelog.md
@@ -1,3 +1,16 @@
+TBD
+==============
+
+Features
+--------
+* Add options to limit the size of LLM prompts; cache LLM prompt data.
+
+
+Bug Fixes
+--------
+* Correct mangled schema info sent in LLM prompts.
+
+
 1.50.0 (2026/02/07)
 ==============
 
diff --git a/mycli/main.py b/mycli/main.py
index b86e4a43..d5a3db81 100755
--- a/mycli/main.py
+++ b/mycli/main.py
@@ -169,6 +169,14 @@ def __init__(
         self.null_string = c['main'].get('null_string')
         self.numeric_alignment = c['main'].get('numeric_alignment', 'right')
         self.binary_display = c['main'].get('binary_display')
+        if 'llm' in c and re.match(r'^\d+$', c['llm'].get('prompt_field_truncate', '')):
+            self.llm_prompt_field_truncate = int(c['llm'].get('prompt_field_truncate'))
+        else:
+            self.llm_prompt_field_truncate = 0
+        if 'llm' in c and re.match(r'^\d+$', c['llm'].get('prompt_section_truncate', '')):
+            self.llm_prompt_section_truncate = int(c['llm'].get('prompt_section_truncate'))
+        else:
+            self.llm_prompt_section_truncate = 0
         # set ssl_mode if a valid option is provided in a config file, otherwise None
         ssl_mode = c["main"].get("ssl_mode", None)
@@ -965,9 +973,16 @@ def one_iteration(text: str | None = None) -> None:
             while special.is_llm_command(text):
                 start = time()
                 try:
+                    assert isinstance(self.sqlexecute, SQLExecute)
                     assert sqlexecute.conn is not None
                     cur = sqlexecute.conn.cursor()
-                    context, sql, duration = special.handle_llm(text, cur)
+                    context, sql, duration = special.handle_llm(
+                        text,
+                        cur,
+                        sqlexecute.dbname or '',
+                        self.llm_prompt_field_truncate,
+                        self.llm_prompt_section_truncate,
+                    )
                     if context:
                         click.echo("LLM Response:")
                         click.echo(context)
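For reference, the new [llm] options read above are only enabled when the raw config value is a plain integer string; anything else, including the shipped default of None, turns the limit off. A minimal standalone sketch of that check (parse_truncate_option is a hypothetical helper written for illustration, not part of the patch):

```python
import re


def parse_truncate_option(raw: str | None) -> int:
    # Mirrors the check in MyCli.__init__ above: only a plain non-negative
    # integer string enables a limit; "None", "", or a missing key yields 0.
    if raw is not None and re.match(r'^\d+$', raw):
        return int(raw)
    return 0


assert parse_truncate_option("1024") == 1024
assert parse_truncate_option("None") == 0  # the default shipped in myclirc
assert parse_truncate_option(None) == 0
```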
diff --git a/mycli/myclirc b/mycli/myclirc
index 6cd25582..6f1a42d7 100644
--- a/mycli/myclirc
+++ b/mycli/myclirc
@@ -176,6 +176,17 @@ default_ssl_cipher =
 # --ssl-verify-server-cert being set
 default_ssl_verify_server_cert = False
 
+[llm]
+
+# If set to a positive integer, truncate text/binary fields to at most that
+# many bytes when sending sample data, to conserve tokens. Suggestion: 1024.
+prompt_field_truncate = None
+
+# If set to a positive integer, attempt to truncate each section of the LLM
+# prompt input to at most that many bytes, to conserve tokens. Suggestion:
+# 1000000.
+prompt_section_truncate = None
+
 [keys]
 # possible values: auto, fzf, reverse_isearch
 control_r = auto
diff --git a/mycli/packages/special/llm.py b/mycli/packages/special/llm.py
index e6023e1d..b8dd437d 100644
--- a/mycli/packages/special/llm.py
+++ b/mycli/packages/special/llm.py
@@ -38,6 +40,10 @@
 LLM_TEMPLATE_NAME = "mycli-llm-template"
 
+SCHEMA_DATA_CACHE: dict[str, str] = {}
+
+SAMPLE_DATA_CACHE: dict[str, dict] = {}
+
 
 def run_external_cmd(
     cmd: str,
@@ -212,7 +216,13 @@ def cli_commands() -> list[str]:
     return list(cli.commands.keys())
 
 
-def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
+def handle_llm(
+    text: str,
+    cur: Cursor,
+    dbname: str,
+    prompt_field_truncate: int,
+    prompt_section_truncate: int,
+) -> tuple[str, str | None, float]:
     _, verbosity, arg = parse_special_command(text)
     if not LLM_IMPORTED:
         output = [(None, None, None, NEED_DEPENDENCIES)]
@@ -261,7 +271,13 @@ def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
     try:
         ensure_mycli_template()
         start = time()
-        context, sql = sql_using_llm(cur=cur, question=arg)
+        context, sql = sql_using_llm(
+            cur=cur,
+            question=arg,
+            dbname=dbname,
+            prompt_field_truncate=prompt_field_truncate,
+            prompt_section_truncate=prompt_section_truncate,
+        )
         end = time()
         if verbosity == Verbosity.SUCCINCT:
             context = ""
@@ -275,51 +291,110 @@ def is_llm_command(command: str) -> bool:
     return cmd in ("\\llm", "\\ai")
 
 
-def sql_using_llm(
-    cur: Cursor | None,
-    question: str | None = None,
-) -> tuple[str, str | None]:
-    if cur is None:
-        raise RuntimeError("Connect to a database and try again.")
-    schema_query = """
-        SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')')
+def truncate_list_elements(row: list, prompt_field_truncate: int, prompt_section_truncate: int) -> list:
+    if not prompt_section_truncate and not prompt_field_truncate:
+        return row
+
+    width = prompt_field_truncate
+    while width >= 0:
+        truncated_row = [x[:width] if isinstance(x, (str, bytes)) else x for x in row]
+        if prompt_section_truncate:
+            if sum(sys.getsizeof(x) for x in truncated_row) <= prompt_section_truncate:
+                break
+            width -= 100
+        else:
+            break
+    return truncated_row
+
+
+def truncate_table_lines(table: list[str], prompt_section_truncate: int) -> list[str]:
+    if not prompt_section_truncate:
+        return table
+
+    truncated_table = []
+    running_sum = 0
+    while table and running_sum <= prompt_section_truncate:
+        line = table.pop(0)
+        running_sum += sys.getsizeof(line)
+        truncated_table.append(line)
+    return truncated_table
+
+
+def get_schema(cur: Cursor, dbname: str, prompt_section_truncate: int) -> str:
+    if dbname in SCHEMA_DATA_CACHE:
+        return SCHEMA_DATA_CACHE[dbname]
+    click.echo("Preparing schema information to feed the LLM")
+    schema_query = f"""
+        SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')')
            AS `schema`
         FROM information_schema.columns
-        WHERE table_schema = DATABASE()
+        WHERE table_schema = '{dbname}'
         GROUP BY table_name
         ORDER BY table_name
     """
-    tables_query = "SHOW TABLES"
-    sample_row_query = "SELECT * FROM `{table}` LIMIT 1"
-    click.echo("Preparing schema information to feed the llm")
     cur.execute(schema_query)
-    db_schema = "\n".join([row[0] for (row,) in cur.fetchall()])
+    db_schema = [row for (row,) in cur.fetchall()]
+    summary = '\n'.join(truncate_table_lines(db_schema, prompt_section_truncate))
+    SCHEMA_DATA_CACHE[dbname] = summary
+    return summary
+
+
+def get_sample_data(
+    cur: Cursor,
+    dbname: str,
+    prompt_field_truncate: int,
+    prompt_section_truncate: int,
+) -> dict[str, Any]:
+    if dbname in SAMPLE_DATA_CACHE:
+        return SAMPLE_DATA_CACHE[dbname]
+    click.echo("Preparing sample data to feed the LLM")
+    tables_query = "SHOW TABLES"
+    sample_row_query = "SELECT * FROM `{dbname}`.`{table}` LIMIT 1"
     cur.execute(tables_query)
     sample_data = {}
     for (table_name,) in cur.fetchall():
         try:
-            cur.execute(sample_row_query.format(table=table_name))
+            cur.execute(sample_row_query.format(dbname=dbname, table=table_name))
         except Exception:
             continue
         cols = [desc[0] for desc in cur.description]
         row = cur.fetchone()
         if row is None:
             continue
-        sample_data[table_name] = list(zip(cols, row, strict=True))
+        sample_data[table_name] = list(
+            zip(cols, truncate_list_elements(list(row), prompt_field_truncate, prompt_section_truncate), strict=False)
+        )
+    SAMPLE_DATA_CACHE[dbname] = sample_data
+    return sample_data
+
+
+def sql_using_llm(
+    cur: Cursor | None,
+    question: str | None,
+    dbname: str = '',
+    prompt_field_truncate: int = 0,
+    prompt_section_truncate: int = 0,
+) -> tuple[str, str | None]:
+    if cur is None:
+        raise RuntimeError("Connect to a database and try again.")
+    if dbname == '':
+        raise RuntimeError("Choose a schema and try again.")
     args = [
         "--template",
         LLM_TEMPLATE_NAME,
         "--param",
         "db_schema",
-        db_schema,
+        get_schema(cur, dbname, prompt_section_truncate),
         "--param",
         "sample_data",
-        sample_data,
+        get_sample_data(cur, dbname, prompt_field_truncate, prompt_section_truncate),
         "--param",
         "question",
         question,
         " ",
     ]
-    click.echo("Invoking llm command with schema information")
+    click.echo(args[4])
+    click.echo(args[7])
+    click.echo("Invoking llm command with schema information and sample data")
     _, result = run_external_cmd("llm", *args, capture_output=True)
     click.echo("Received response from the llm command")
     match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
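To illustrate how the two limits above interact when both are set, here is a standalone sketch that assumes this patch is applied; the concrete numbers are illustrative assumptions, not requirements:

```python
from mycli.packages.special.llm import truncate_list_elements

# A sample row with two long text fields. With prompt_field_truncate=1024 and
# prompt_section_truncate=1500, each str/bytes value is first clipped to the
# field limit; the clip width then shrinks in steps of 100 until the
# sys.getsizeof() total for the row fits the section budget (or width hits 0).
row = ["a" * 10_000, "b" * 10_000]
clipped = truncate_list_elements(row, 1024, 1500)
print([len(x) for x in clipped])  # both values end up well below 1024
```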
diff --git a/test/myclirc b/test/myclirc
index 9950be0d..ea4e1497 100644
--- a/test/myclirc
+++ b/test/myclirc
@@ -174,6 +174,17 @@ default_ssl_cipher =
 # --ssl-verify-server-cert being set
 default_ssl_verify_server_cert = False
 
+[llm]
+
+# If set to a positive integer, truncate text/binary fields to at most that
+# many bytes when sending sample data, to conserve tokens. Suggestion: 1024.
+prompt_field_truncate = None
+
+# If set to a positive integer, attempt to truncate each section of the LLM
+# prompt input to at most that many bytes, to conserve tokens. Suggestion:
+# 1000000.
+prompt_section_truncate = None
+
 [keys]
 # possible values: auto, fzf, reverse_isearch
 control_r = auto
diff --git a/test/test_llm_special.py b/test/test_llm_special.py
index a7fa578a..3ba143e9 100644
--- a/test/test_llm_special.py
+++ b/test/test_llm_special.py
@@ -26,7 +26,7 @@ def test_llm_command_without_args(mock_llm, executor):
     assert mock_llm is not None
     test_text = r"\llm"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql', 0, 0)
 
     # Should return usage message when no args provided
     assert exc_info.value.args[0] == [(None, None, None, USAGE)]
@@ -38,7 +38,7 @@ def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor):
     mock_run_cmd.return_value = (0, "Hello, no SQL today.")
     test_text = r"\llm -c 'Something?'"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql', 0, 0)
 
     # Expect raw output when no SQL fence found
     assert exc_info.value.args[0] == [(None, None, None, "Hello, no SQL today.")]
@@ -51,7 +51,7 @@ def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor
     fenced = f"Here you go:\n```sql\n{sql_text}\n```"
     mock_run_cmd.return_value = (0, fenced)
     test_text = r"\llm -c 'Rewrite SQL'"
-    result, sql, duration = handle_llm(test_text, executor)
+    result, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0)
     # Without verbose, result is empty, sql extracted
     assert sql == sql_text
     assert result == ""
@@ -64,7 +64,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
     # 'models' is a known subcommand
     test_text = r"\llm models"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql', 0, 0)
 
     mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False)
     assert exc_info.value.args[0] is None
@@ -74,7 +74,7 @@ def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
     test_text = r"\llm --help"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql', 0, 0)
 
     mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False)
     assert exc_info.value.args[0] is None
@@ -84,7 +84,7 @@ def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
     test_text = r"\llm install openai"
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(test_text, executor)
+        handle_llm(test_text, executor, 'mysql', 0, 0)
 
     mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
     assert exc_info.value.args[0] is None
@@ -98,7 +98,7 @@ def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_
     """
     mock_sql_using_llm.return_value = ("CTX", "SELECT 1;")
     test_text = r"\llm prompt 'Test?'"
-    context, sql, duration = handle_llm(test_text, executor)
+    context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0)
     mock_ensure_template.assert_called_once()
     mock_sql_using_llm.assert_called()
     assert context == "CTX"
@@ -115,7 +115,7 @@ def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_templ
     """
     mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;")
     test_text = r"\llm 'Top 10?'"
-    context, sql, duration = handle_llm(test_text, executor)
+    context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0)
     mock_ensure_template.assert_called_once()
     mock_sql_using_llm.assert_called()
     assert context == "CTX2"
@@ -132,7 +132,7 @@ def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template,
     """
     mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;")
     test_text = r"\llm- 'Succinct?'"
-    context, sql, duration = handle_llm(test_text, executor)
+    context, sql, duration = handle_llm(test_text, executor, 'mysql', 0, 0)
     assert context == ""
     assert sql == "SELECT 42;"
     assert isinstance(duration, float)
@@ -181,7 +181,7 @@ def fetchone(self):
     sql_text = "SELECT 1, 'abc';"
     fenced = f"Note\n```sql\n{sql_text}\n```"
     mock_run_cmd.return_value = (0, fenced)
-    result, sql = sql_using_llm(dummy_cur, question="dummy")
+    result, sql = sql_using_llm(dummy_cur, question="dummy", dbname='mysql')
     assert result == fenced
     assert sql == sql_text
@@ -194,5 +194,5 @@ def test_handle_llm_aliases_without_args(prefix, executor, monkeypatch):
     monkeypatch.setattr(llm_module, "llm", object())
 
     with pytest.raises(FinishIteration) as exc_info:
-        handle_llm(prefix, executor)
+        handle_llm(prefix, executor, 'mysql', 0, 0)
     assert exc_info.value.args[0] == [(None, None, None, USAGE)]
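The new truncation helpers themselves are not exercised by the tests above; a possible follow-up test, shown only as a sketch and not part of this patch, with expected values that assume the helpers exactly as written in the llm.py hunk:

```python
from mycli.packages.special.llm import truncate_list_elements, truncate_table_lines


def test_truncate_helpers_respect_limits():
    # Field truncation caps str/bytes values; other types pass through untouched.
    row = ["x" * 5000, b"y" * 5000, 42]
    truncated = truncate_list_elements(row, 1024, 0)
    assert len(truncated[0]) == 1024
    assert len(truncated[1]) == 1024
    assert truncated[2] == 42

    # With both limits disabled the row is returned unchanged.
    assert truncate_list_elements(row, 0, 0) is row

    # Section truncation keeps leading lines until the byte budget is exhausted.
    lines = ["a" * 100 for _ in range(100)]
    kept = truncate_table_lines(list(lines), 1000)
    assert 0 < len(kept) < len(lines)
```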