diff --git a/.gitignore b/.gitignore index 29771bd1..e47f41d2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,5 @@ tests/benchmark/benchmark-*.png testing-chart.yaml .env.proxy.dev + +/.work diff --git a/Cargo.lock b/Cargo.lock index 11adb87b..dea86e46 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -762,9 +762,9 @@ dependencies = [ [[package]] name = "cipherstash-client" -version = "0.34.0-alpha.4" +version = "0.34.1-alpha.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "200537bf2ab562b085e34df7e3391d0426ab04eea3ed588a7fc27f1bd218ee33" +checksum = "5ed912188c76e36b5e2fc569a9850712e416df2630837e2ca36c32cd48f5488f" dependencies = [ "aes-gcm-siv", "anyhow", @@ -776,7 +776,7 @@ dependencies = [ "blake3", "cfg-if", "chrono", - "cipherstash-config 0.34.0-alpha.4", + "cipherstash-config 0.34.1-alpha.3", "cipherstash-core", "cllw-ore", "cts-common", @@ -794,7 +794,7 @@ dependencies = [ "ore-rs", "percent-encoding", "rand 0.8.5", - "recipher 0.2.0", + "recipher 0.2.1", "reqwest", "reqwest-middleware", "reqwest-retry", @@ -837,20 +837,21 @@ dependencies = [ [[package]] name = "cipherstash-config" -version = "0.34.0-alpha.4" +version = "0.34.1-alpha.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "333ba6c42338ce6bbbc515fb75e43b57311ece1a9ea41e7daabe50478c342841" +checksum = "f6ffe865456a86868f2f544991e3f24986e1eac330e20879a2c132f27d8befdf" dependencies = [ "bitflags", "serde", + "serde_json", "thiserror 1.0.69", ] [[package]] name = "cipherstash-core" -version = "0.34.0-alpha.4" +version = "0.34.1-alpha.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32921e505e39f8f7cae9f55e82462d8dd92764a9148f479b42abf52e60e90437" +checksum = "b696354ee0947f14d2c88ec70300896411d291a3f32015938ccf20bf34f724b3" dependencies = [ "hmac", "lazy_static", @@ -873,6 +874,7 @@ dependencies = [ "bytes", "chrono", "cipherstash-client", + "cipherstash-config 0.34.1-alpha.3", "clap", "config", "cts-common", 
@@ -1192,9 +1194,9 @@ dependencies = [ [[package]] name = "cts-common" -version = "0.34.0-alpha.4" +version = "0.34.1-alpha.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7817fb03b19c6a588bc9120fd876a6d65f531a0b2aa0d39384bc78f3c4c4340" +checksum = "6b13e6862c6edffe9a8e9c36edb99dacf9e140f217c950f06801c570433bc038" dependencies = [ "arrayvec", "axum", @@ -1215,6 +1217,7 @@ dependencies = [ "regex", "serde", "thiserror 1.0.69", + "tracing", "url", "utoipa", "uuid", @@ -3501,9 +3504,9 @@ dependencies = [ [[package]] name = "recipher" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "061598013445a8bb847d0c95ee33b5e95c1d198d5242b6a8b9f3078aa7437e79" +checksum = "142637705fc952f975681a62dc0772aab3afec9e68955c5539f02cdb711c09aa" dependencies = [ "aes", "async-trait", @@ -4349,9 +4352,9 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stack-auth" -version = "0.34.0-alpha.4" +version = "0.34.1-alpha.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3e8a681ffc8eb40575fb5f40b8316f1b9e03074eb1e4951e0690b00b0349fed" +checksum = "000f8c46c6e86f119af7fc38c61cd84b8f3b26f29086690a7112b164867c578d" dependencies = [ "aquamarine", "cts-common", @@ -4370,13 +4373,14 @@ dependencies = [ "vitaminc", "vitaminc-protected", "zeroize", + "zerokms-protocol", ] [[package]] name = "stack-profile" -version = "0.34.0" +version = "0.34.1-alpha.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56fdb1e5ef2111e616fb46da39ad63485b3f3c82de3245fe3c14ce52e8775112" +checksum = "d84b6c5cb4759f2d6894c6ce4f3a76acbb4ab8ea592970733730ca76cea9f161" dependencies = [ "dirs", "gethostname", @@ -6286,12 +6290,12 @@ dependencies = [ [[package]] name = "zerokms-protocol" -version = "0.12.3" +version = "0.12.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"f52f1d857d2e6d4fe258c49906d53f8b2666c4841dc2e39e67cfea3717382294" +checksum = "7ca10678317fddf3aaead818e0226009925795f75050730413384154db98f98d" dependencies = [ "base64", - "cipherstash-config 0.34.0-alpha.4", + "cipherstash-config 0.34.1-alpha.3", "const-hex", "cts-common", "fake 2.10.0", diff --git a/Cargo.toml b/Cargo.toml index d1a09aec..d081efe9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,8 +43,9 @@ debug = true [workspace.dependencies] sqltk = { version = "0.10.0" } -cipherstash-client = { version = "0.34.0-alpha.4" } -cts-common = { version = "0.34.0-alpha.4" } +cipherstash-client = { version = "=0.34.1-alpha.3" } +cipherstash-config = { version = "=0.34.1-alpha.3" } +cts-common = { version = "=0.34.1-alpha.3" } thiserror = "2.0.9" tokio = { version = "1.44.2", features = ["full"] } diff --git a/docs/plans/2026-04-01-proxy-backwards-compat-tests.md b/docs/plans/2026-04-01-proxy-backwards-compat-tests.md new file mode 100644 index 00000000..888c0505 --- /dev/null +++ b/docs/plans/2026-04-01-proxy-backwards-compat-tests.md @@ -0,0 +1,275 @@ +# Proxy Backwards Compatibility Tests for Canonical Config Migration + +> **For Claude:** REQUIRED SUB-SKILL: Use cipherpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Verify the proxy's integration pipeline from JSON → `CanonicalEncryptionConfig` → `EncryptConfig` works correctly, including Identifier conversion, error handling, and ColumnType mapping. + +**Tech Stack:** Rust, serde_json, cipherstash-config, cipherstash-client + +--- + +### Task 1: Test Identifier bridging in EncryptConfig + +**File:** `packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs` (test module) + +The `load_encrypt_config` function converts `cipherstash_config::Identifier` → `cipherstash_client::eql::Identifier`. Test that this preserves table/column names correctly. 
+ +Add these tests to the existing test module: + +```rust +#[test] +fn config_map_preserves_table_and_column_names() { + let json = json!({ + "v": 1, + "tables": { + "my_schema.users": { + "email_address": { + "cast_as": "text", + "indexes": { "unique": {} } + } + } + } + }); + + let config = parse(json); + + let ident = Identifier::new("my_schema.users", "email_address"); + let column = config.get(&ident).expect("column exists"); + assert_eq!(column.name, "email_address"); + assert_eq!(column.cast_type, ColumnType::Text); +} + +#[test] +fn config_map_handles_multiple_tables() { + let json = json!({ + "v": 1, + "tables": { + "users": { + "email": { "cast_as": "text" } + }, + "orders": { + "total": { "cast_as": "int" } + } + } + }); + + let config = parse(json); + + assert_eq!(config.len(), 2); + assert!(config.contains_key(&Identifier::new("users", "email"))); + assert!(config.contains_key(&Identifier::new("orders", "total"))); +} +``` + +**Verify:** `cargo test -p cipherstash-proxy --lib -- encrypt_config` + +--- + +### Task 2: Test ColumnType mapping through column_type_to_postgres_type + +**File:** `packages/cipherstash-proxy/src/postgresql/context/column.rs` + +The rename from `Utf8Str` → `Text` and `JsonB` → `Json` must produce the same PostgreSQL types. Add tests to verify the mapping. 
+ +Add a test module to `column.rs`: + +```rust +#[cfg(test)] +mod tests { + use super::*; + use eql_mapper::EqlTermVariant; + + #[test] + fn text_column_maps_to_postgres_text() { + assert_eq!( + column_type_to_postgres_type(&ColumnType::Text, EqlTermVariant::Full), + postgres_types::Type::TEXT + ); + } + + #[test] + fn json_column_maps_to_postgres_jsonb() { + assert_eq!( + column_type_to_postgres_type(&ColumnType::Json, EqlTermVariant::Full), + postgres_types::Type::JSONB + ); + } + + #[test] + fn json_accessor_maps_to_postgres_text() { + assert_eq!( + column_type_to_postgres_type(&ColumnType::Json, EqlTermVariant::JsonAccessor), + postgres_types::Type::TEXT + ); + } + + #[test] + fn all_column_types_have_postgres_mapping() { + let types = vec![ + ColumnType::Boolean, + ColumnType::BigInt, + ColumnType::BigUInt, + ColumnType::Date, + ColumnType::Decimal, + ColumnType::Float, + ColumnType::Int, + ColumnType::SmallInt, + ColumnType::Timestamp, + ColumnType::Text, + ColumnType::Json, + ]; + + for ct in types { + // Should not panic + let _ = column_type_to_postgres_type(&ct, EqlTermVariant::Full); + } + } +} +``` + +**Verify:** `cargo test -p cipherstash-proxy --lib -- context::column` + +--- + +### Task 3: Test error propagation for invalid configs + +**File:** `packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs` (test module) + +The canonical `into_config_map()` can now return errors (e.g., ste_vec on non-JSON column). Verify the error surfaces correctly through the proxy's error types. 
+ +```rust +#[test] +fn invalid_config_returns_error() { + let json = json!({ + "v": 1, + "tables": { + "users": { + "email": { + "cast_as": "text", + "indexes": { + "ste_vec": { "prefix": "test" } + } + } + } + } + }); + + let config: CanonicalEncryptionConfig = serde_json::from_value(json).unwrap(); + let result = config.into_config_map(); + assert!(result.is_err(), "ste_vec on text column should fail validation"); +} + +#[test] +fn malformed_json_returns_parse_error() { + let json = json!({ + "v": 1, + "tables": "not a map" + }); + + let result = serde_json::from_value::<CanonicalEncryptionConfig>(json); + assert!(result.is_err()); +} +``` + +**Verify:** `cargo test -p cipherstash-proxy --lib -- encrypt_config` + +--- + +### Task 4: Test real integration schema config shape + +**File:** `packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs` (test module) + +Use the same fixture from the cipherstash-config plan — the JSON shape matching the proxy's integration test schema. Verify the full pipeline including Identifier conversion. 
+ +```rust +#[test] +fn real_eql_config_produces_correct_encrypt_config() { + let json = json!({ + "v": 1, + "tables": { + "encrypted": { + "encrypted_text": { + "cast_as": "text", + "indexes": { "unique": {}, "match": {}, "ore": {} } + }, + "encrypted_bool": { + "cast_as": "boolean", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_int2": { + "cast_as": "small_int", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_int4": { + "cast_as": "int", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_int8": { + "cast_as": "big_int", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_float8": { + "cast_as": "double", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_date": { + "cast_as": "date", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_jsonb": { + "cast_as": "jsonb", + "indexes": { + "ste_vec": { "prefix": "encrypted/encrypted_jsonb" } + } + }, + "encrypted_jsonb_filtered": { + "cast_as": "jsonb", + "indexes": { + "ste_vec": { + "prefix": "encrypted/encrypted_jsonb_filtered", + "term_filters": [{ "kind": "downcase" }] + } + } + } + } + } + }); + + let config = parse(json); + + // All 9 columns present with correct Identifiers + assert_eq!(config.len(), 9); + + // Verify legacy type aliases map correctly + let float_col = config.get(&Identifier::new("encrypted", "encrypted_float8")).unwrap(); + assert_eq!(float_col.cast_type, ColumnType::Float); + + let jsonb_col = config.get(&Identifier::new("encrypted", "encrypted_jsonb")).unwrap(); + assert_eq!(jsonb_col.cast_type, ColumnType::Json); + + // Verify index counts + let text_col = config.get(&Identifier::new("encrypted", "encrypted_text")).unwrap(); + assert_eq!(text_col.indexes.len(), 3); + + let bool_col = config.get(&Identifier::new("encrypted", "encrypted_bool")).unwrap(); + assert_eq!(bool_col.indexes.len(), 2); + + let jsonb_filtered = config.get(&Identifier::new("encrypted", "encrypted_jsonb_filtered")).unwrap(); + assert_eq!(jsonb_filtered.indexes.len(), 
1); +} +``` + +**Verify:** `cargo test -p cipherstash-proxy --lib -- encrypt_config` + +--- + +### Task 5: Full verification + +Run the complete test suite: + +```bash +cargo clippy --no-deps --tests --all-features --all-targets -p cipherstash-proxy -- -D warnings +cargo test -p cipherstash-proxy --lib +``` + +All tests must pass, zero clippy warnings. diff --git a/mise.local.example.toml b/mise.local.example.toml index 0a22e27e..cb98df27 100644 --- a/mise.local.example.toml +++ b/mise.local.example.toml @@ -15,7 +15,7 @@ CS_CLIENT_KEY = "client-key" CS_CLIENT_ID = "client-id" # The release of EQL that the proxy tests will use and releases will be built with -CS_EQL_VERSION = "eql-2.2.1" +CS_EQL_VERSION = "eql-2.3.0-pre.2" # TLS variables are required for providing TLS to Proxy's clients. # CS_TLS__TYPE can be either "Path" or "Pem" (case-sensitive). diff --git a/mise.toml b/mise.toml index 516dc382..9585c9a7 100644 --- a/mise.toml +++ b/mise.toml @@ -34,7 +34,7 @@ CS_PROXY__HOST = "host.docker.internal" # Misc DOCKER_CLI_HINTS = "false" # Please don't show us What's Next. 
-CS_EQL_VERSION = "eql-2.2.1" +CS_EQL_VERSION = "eql-2.3.0-pre.2" [tools] diff --git a/packages/cipherstash-proxy/Cargo.toml b/packages/cipherstash-proxy/Cargo.toml index 077764ff..34abe065 100644 --- a/packages/cipherstash-proxy/Cargo.toml +++ b/packages/cipherstash-proxy/Cargo.toml @@ -12,6 +12,7 @@ arc-swap = "1.7.1" bytes = { version = "1.9", default-features = false } chrono = { version = "0.4.39", features = ["clock"] } cipherstash-client = { workspace = true, features = ["tokio"] } +cipherstash-config = { workspace = true } clap = { version = "4.5.31", features = ["derive", "env"] } config = { version = "0.15", features = [ "async", diff --git a/packages/cipherstash-proxy/src/error.rs b/packages/cipherstash-proxy/src/error.rs index c192cb9d..aee61575 100644 --- a/packages/cipherstash-proxy/src/error.rs +++ b/packages/cipherstash-proxy/src/error.rs @@ -185,6 +185,9 @@ pub enum ConfigError { #[error(transparent)] Parse(#[from] serde_json::Error), + #[error("Invalid encryption configuration: {0}")] + InvalidEncryptionConfig(#[from] cipherstash_config::errors::ConfigError), + #[error("Database schema could not be loaded")] SchemaCouldNotBeLoaded, diff --git a/packages/cipherstash-proxy/src/postgresql/context/column.rs b/packages/cipherstash-proxy/src/postgresql/context/column.rs index 0088f0b3..4fd9c720 100644 --- a/packages/cipherstash-proxy/src/postgresql/context/column.rs +++ b/packages/cipherstash-proxy/src/postgresql/context/column.rs @@ -77,8 +77,60 @@ fn column_type_to_postgres_type( (ColumnType::Int, _) => postgres_types::Type::INT4, (ColumnType::SmallInt, _) => postgres_types::Type::INT2, (ColumnType::Timestamp, _) => postgres_types::Type::TIMESTAMPTZ, - (ColumnType::Utf8Str, _) => postgres_types::Type::TEXT, - (ColumnType::JsonB, EqlTermVariant::JsonAccessor) => postgres_types::Type::TEXT, - (ColumnType::JsonB, _) => postgres_types::Type::JSONB, + (ColumnType::Text, _) => postgres_types::Type::TEXT, + (ColumnType::Json, EqlTermVariant::JsonAccessor) => 
postgres_types::Type::TEXT, + (ColumnType::Json, _) => postgres_types::Type::JSONB, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use eql_mapper::EqlTermVariant; + + #[test] + fn text_column_maps_to_postgres_text() { + assert_eq!( + column_type_to_postgres_type(&ColumnType::Text, EqlTermVariant::Full), + postgres_types::Type::TEXT + ); + } + + #[test] + fn json_column_maps_to_postgres_jsonb() { + assert_eq!( + column_type_to_postgres_type(&ColumnType::Json, EqlTermVariant::Full), + postgres_types::Type::JSONB + ); + } + + #[test] + fn json_accessor_maps_to_postgres_text() { + assert_eq!( + column_type_to_postgres_type(&ColumnType::Json, EqlTermVariant::JsonAccessor), + postgres_types::Type::TEXT + ); + } + + #[test] + fn all_column_types_have_postgres_mapping() { + let types = vec![ + ColumnType::Boolean, + ColumnType::BigInt, + ColumnType::BigUInt, + ColumnType::Date, + ColumnType::Decimal, + ColumnType::Float, + ColumnType::Int, + ColumnType::SmallInt, + ColumnType::Timestamp, + ColumnType::Text, + ColumnType::Json, + ]; + + for ct in types { + // Should not panic + let _ = column_type_to_postgres_type(&ct, EqlTermVariant::Full); + } } } diff --git a/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs b/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs index 5736bdee..f70e4962 100644 --- a/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs +++ b/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs @@ -115,7 +115,7 @@ pub fn literal_from_sql( /// /// | Input Type | Target Column Type | Result | /// |------------|--------------------|--------| -/// | `Type::INT4` | `ColumnType::Utf8Str` | `Plaintext::Utf8Str` | +/// | `Type::INT4` | `ColumnType::Text` | `Plaintext::Text` | /// | `Type::INT2` | `ColumnType::Int` | `Plaintext::Int` | /// | `Type::INT8` | `ColumnType::Int` | `Error`` | fn text_from_sql( @@ -126,7 +126,7 @@ fn text_from_sql( debug!(target: ENCODING, ?val, ?eql_term, ?col_type); match (eql_term, col_type) { - 
(EqlTermVariant::Full | EqlTermVariant::Partial, ColumnType::Utf8Str) => { + (EqlTermVariant::Full | EqlTermVariant::Partial, ColumnType::Text) => { Ok(Plaintext::new(val)) } (EqlTermVariant::Full | EqlTermVariant::Partial, ColumnType::Float) => { @@ -168,7 +168,7 @@ fn text_from_sql( } // If JSONB, JSONPATH values are treated as strings - (EqlTermVariant::JsonPath | EqlTermVariant::JsonAccessor, ColumnType::JsonB) => { + (EqlTermVariant::JsonPath | EqlTermVariant::JsonAccessor, ColumnType::Json) => { let val = if val.starts_with("$.") { val.to_string() } else { @@ -176,12 +176,12 @@ fn text_from_sql( }; Ok(Plaintext::new(val)) } - (EqlTermVariant::Full | EqlTermVariant::Partial, ColumnType::JsonB) => { + (EqlTermVariant::Full | EqlTermVariant::Partial, ColumnType::Json) => { serde_json::from_str::(val) .map_err(|_| MappingError::CouldNotParseParameter) .map(Plaintext::new) } - (EqlTermVariant::Tokenized, ColumnType::Utf8Str) => Ok(Plaintext::new(val)), + (EqlTermVariant::Tokenized, ColumnType::Text) => Ok(Plaintext::new(val)), (eql_term, col_type) => Err(MappingError::UnsupportedParameterType { eql_term, @@ -202,7 +202,7 @@ fn binary_from_sql( debug!(target: ENCODING, ?pg_type, ?eql_term, ?col_type); match (eql_term, col_type, pg_type) { - (EqlTermVariant::Full | EqlTermVariant::Partial, ColumnType::Utf8Str, _) => { + (EqlTermVariant::Full | EqlTermVariant::Partial, ColumnType::Text, _) => { parse_bytes_from_sql::(bytes, pg_type).map(Plaintext::new) } (EqlTermVariant::Full | EqlTermVariant::Partial, ColumnType::Boolean, _) => { @@ -253,7 +253,7 @@ fn binary_from_sql( } // If JSONB, JSONPATH values are treated as strings - (EqlTermVariant::JsonPath, ColumnType::JsonB, &Type::JSONPATH) => { + (EqlTermVariant::JsonPath, ColumnType::Json, &Type::JSONPATH) => { parse_bytes_from_sql::(bytes, pg_type).map(|val| { let val = if val.starts_with("$.") { val @@ -263,7 +263,7 @@ fn binary_from_sql( Plaintext::new(val) }) } - (EqlTermVariant::JsonAccessor, ColumnType::JsonB, 
&Type::TEXT | &Type::VARCHAR) => { + (EqlTermVariant::JsonAccessor, ColumnType::Json, &Type::TEXT | &Type::VARCHAR) => { parse_bytes_from_sql::(bytes, pg_type).map(|val| { let val = if val.starts_with("$.") { val @@ -276,7 +276,7 @@ fn binary_from_sql( // Python psycopg sends JSON/B as BYTEA ( EqlTermVariant::Full | EqlTermVariant::Partial, - ColumnType::JsonB, + ColumnType::Json, &Type::JSON | &Type::JSONB | &Type::BYTEA, ) => parse_bytes_from_sql::(bytes, pg_type).map(Plaintext::new), @@ -356,9 +356,9 @@ fn decimal_from_sql( .ok_or(MappingError::CouldNotParseParameter) .map(Plaintext::new), - ColumnType::Utf8Str => Ok(Plaintext::new(decimal.to_string())), + ColumnType::Text => Ok(Plaintext::new(decimal.to_string())), - ColumnType::JsonB => { + ColumnType::Json => { let val: serde_json::Value = serde_json::from_str(&decimal.to_string()) .map_err(|_| MappingError::CouldNotParseParameter)?; Ok(Plaintext::new(val)) @@ -408,7 +408,7 @@ mod tests { config: ColumnConfig { name: "column".to_owned(), in_place: false, - cast_type: ColumnType::Utf8Str, + cast_type: ColumnType::Text, indexes: vec![], mode: ColumnMode::PlaintextDuplicate, }, diff --git a/packages/cipherstash-proxy/src/postgresql/data/to_sql.rs b/packages/cipherstash-proxy/src/postgresql/data/to_sql.rs index 6bc57d73..1007d266 100644 --- a/packages/cipherstash-proxy/src/postgresql/data/to_sql.rs +++ b/packages/cipherstash-proxy/src/postgresql/data/to_sql.rs @@ -16,7 +16,7 @@ pub fn to_sql(plaintext: &Plaintext, format_code: &FormatCode) -> Result Result { let s = match &plaintext { - Plaintext::Utf8Str(Some(x)) => x.to_string(), + Plaintext::Text(Some(x)) => x.to_string(), Plaintext::Int(Some(x)) => x.to_string(), Plaintext::BigInt(Some(x)) => x.to_string(), Plaintext::BigUInt(Some(x)) => x.to_string(), @@ -26,7 +26,7 @@ fn text_to_sql(plaintext: &Plaintext) -> Result { Plaintext::NaiveDate(Some(x)) => x.to_string(), Plaintext::SmallInt(Some(x)) => x.to_string(), Plaintext::Timestamp(Some(x)) => x.to_string(), 
- Plaintext::JsonB(Some(x)) => x.to_string(), + Plaintext::Json(Some(x)) => x.to_string(), _ => "".to_string(), }; @@ -44,8 +44,8 @@ fn binary_to_sql(plaintext: &Plaintext) -> Result { Plaintext::NaiveDate(x) => x.to_sql_checked(&Type::DATE, &mut bytes), Plaintext::SmallInt(x) => x.to_sql_checked(&Type::INT2, &mut bytes), Plaintext::Timestamp(x) => x.to_sql_checked(&Type::TIMESTAMPTZ, &mut bytes), - Plaintext::Utf8Str(x) => x.to_sql_checked(&Type::TEXT, &mut bytes), - Plaintext::JsonB(x) => x.to_sql_checked(&Type::JSONB, &mut bytes), + Plaintext::Text(x) => x.to_sql_checked(&Type::TEXT, &mut bytes), + Plaintext::Json(x) => x.to_sql_checked(&Type::JSONB, &mut bytes), Plaintext::Decimal(x) => x.to_sql_checked(&Type::NUMERIC, &mut bytes), // TODO: Implement these Plaintext::BigUInt(_x) => unimplemented!(), diff --git a/packages/cipherstash-proxy/src/postgresql/frontend.rs b/packages/cipherstash-proxy/src/postgresql/frontend.rs index 780360bb..87199677 100644 --- a/packages/cipherstash-proxy/src/postgresql/frontend.rs +++ b/packages/cipherstash-proxy/src/postgresql/frontend.rs @@ -660,7 +660,7 @@ where let encrypted_nodes = typed_statement .literals .iter() - .zip(encrypted_expressions.into_iter()) + .zip(encrypted_expressions) .filter_map(|((_, original_node), en)| en.map(|en| (NodeKey::new(*original_node), en))) .collect::>(); diff --git a/packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs b/packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs deleted file mode 100644 index 41ae3e13..00000000 --- a/packages/cipherstash-proxy/src/proxy/encrypt_config/config.rs +++ /dev/null @@ -1,489 +0,0 @@ -use crate::error::{ConfigError, Error}; -use cipherstash_client::{ - eql::Identifier, - schema::{ - column::{ArrayIndexMode, Index, IndexType, TokenFilter, Tokenizer}, - ColumnConfig, ColumnType, - }, -}; -use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, str::FromStr}; - -#[derive(Debug, Deserialize, Serialize, Clone, Default)] -pub struct 
ColumnEncryptionConfig { - #[serde(rename = "v")] - pub version: u32, - pub tables: Tables, -} - -#[derive(Debug, Deserialize, Serialize, Clone, Default)] -pub struct Tables(HashMap); - -impl IntoIterator for Tables { - type Item = (String, Table); - type IntoIter = std::collections::hash_map::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -#[derive(Debug, Deserialize, Serialize, Clone, Default)] -pub struct Table(HashMap); - -impl IntoIterator for Table { - type Item = (String, Column); - type IntoIter = std::collections::hash_map::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.0.into_iter() - } -} - -#[derive(Debug, Default, Deserialize, Serialize, Clone, PartialEq)] -pub struct Column { - #[serde(default)] - cast_as: CastAs, - #[serde(default)] - indexes: Indexes, -} - -#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "snake_case")] -pub enum CastAs { - BigInt, - Boolean, - Date, - Real, - Double, - Int, - SmallInt, - #[default] - Text, - #[serde(rename = "jsonb")] - JsonB, -} - -#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq)] -pub struct Indexes { - #[serde(rename = "ore")] - ore_index: Option, - #[serde(rename = "unique")] - unique_index: Option, - #[serde(rename = "match")] - match_index: Option, - #[serde(rename = "ste_vec")] - ste_vec_index: Option, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] -pub struct OreIndexOpts {} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] -pub struct MatchIndexOpts { - #[serde(default = "default_tokenizer")] - tokenizer: Tokenizer, - #[serde(default)] - token_filters: Vec, - #[serde(default = "default_k")] - k: usize, - #[serde(default = "default_m")] - m: usize, - #[serde(default)] - include_original: bool, -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] -pub struct SteVecIndexOpts { - prefix: String, - #[serde(default)] - term_filters: Vec, - #[serde(default = 
"default_array_index_mode")] - array_index_mode: ArrayIndexMode, -} - -fn default_array_index_mode() -> ArrayIndexMode { - ArrayIndexMode::ALL -} - -fn default_tokenizer() -> Tokenizer { - Tokenizer::Standard -} - -fn default_k() -> usize { - 6 -} - -fn default_m() -> usize { - 2048 -} - -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] -pub struct UniqueIndexOpts { - #[serde(default)] - token_filters: Vec, -} - -impl From for ColumnType { - fn from(value: CastAs) -> Self { - match value { - CastAs::BigInt => ColumnType::BigInt, - CastAs::SmallInt => ColumnType::SmallInt, - CastAs::Int => ColumnType::Int, - CastAs::Boolean => ColumnType::Boolean, - CastAs::Date => ColumnType::Date, - CastAs::Real | CastAs::Double => ColumnType::Float, - CastAs::Text => ColumnType::Utf8Str, - CastAs::JsonB => ColumnType::JsonB, - } - } -} - -impl FromStr for ColumnEncryptionConfig { - type Err = Error; - - fn from_str(data: &str) -> Result { - let config = serde_json::from_str(data).map_err(ConfigError::Parse)?; - Ok(config) - } -} - -impl ColumnEncryptionConfig { - pub fn is_empty(&self) -> bool { - self.tables.0.is_empty() - } - - pub fn into_config_map(self) -> HashMap { - let mut map = HashMap::new(); - for (table_name, columns) in self.tables.into_iter() { - for (column_name, column) in columns.into_iter() { - let column_config = column.into_column_config(&column_name); - let key = Identifier::new(&table_name, &column_name); - map.insert(key, column_config); - } - } - map - } -} - -impl Column { - pub fn into_column_config(self, name: &String) -> ColumnConfig { - let mut config = ColumnConfig::build(name.to_string()).casts_as(self.cast_as.into()); - - if self.indexes.ore_index.is_some() { - config = config.add_index(Index::new_ore()); - } - - if let Some(opts) = self.indexes.match_index { - config = config.add_index(Index::new(IndexType::Match { - tokenizer: opts.tokenizer, - token_filters: opts.token_filters, - k: opts.k, - m: opts.m, - include_original: 
opts.include_original, - })); - } - - if let Some(opts) = self.indexes.unique_index { - config = config.add_index(Index::new(IndexType::Unique { - token_filters: opts.token_filters, - })) - } - - if let Some(SteVecIndexOpts { - prefix, - term_filters, - array_index_mode, - }) = self.indexes.ste_vec_index - { - config = config.add_index(Index::new(IndexType::SteVec { - prefix, - term_filters, - array_index_mode, - })) - } - - config - } -} - -#[cfg(test)] -mod tests { - use cipherstash_client::eql::Identifier; - use serde_json::json; - - use super::*; - - fn parse(json: serde_json::Value) -> HashMap { - serde_json::from_value::(json) - .map(|config| config.into_config_map()) - .expect("Error ok") - } - - #[test] - fn column_with_empty_options_gets_defaults() { - let json = json!({ - "v": 1, - "tables": { - "users": { - "email": {} - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "email"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert_eq!(column.cast_type, ColumnType::Utf8Str); - assert!(column.indexes.is_empty()); - } - - #[test] - fn can_parse_column_with_cast_as() { - let json = json!({ - "v": 1, - "tables": { - "users": { - "favourite_int": { - "cast_as": "int" - } - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "favourite_int"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert_eq!(column.cast_type, ColumnType::Int); - assert_eq!(column.name, "favourite_int"); - assert!(column.indexes.is_empty()); - } - - #[test] - fn can_parse_empty_indexes() { - let json = json!({ - "v": 1, - "tables": { - "users": { - "email": { - "indexes": {} - } - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "email"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert!(column.indexes.is_empty()); - } - - #[test] - fn can_parse_ore_index() { - let json = json!({ - 
"v": 1, - "tables": { - "users": { - "email": { - "indexes": { - "ore": {} - } - } - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "email"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert_eq!(column.indexes[0].index_type, IndexType::Ore); - } - - #[test] - fn can_parse_unique_index_with_defaults() { - let json = json!({ - "v": 1, - "tables": { - "users": { - "email": { - "indexes": { - "unique": {} - } - } - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "email"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert_eq!( - column.indexes[0].index_type, - IndexType::Unique { - token_filters: vec![] - } - ); - } - - #[test] - fn can_parse_unique_index_with_token_filter() { - let json = json!({ - "v": 1, - "tables": { - "users": { - "email": { - "indexes": { - "unique": { - "token_filters": [ - { - "kind": "downcase" - } - ] - } - } - } - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "email"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert_eq!( - column.indexes[0].index_type, - IndexType::Unique { - token_filters: vec![TokenFilter::Downcase] - } - ); - } - - #[test] - fn can_parse_match_index_with_defaults() { - let json = json!({ - "v": 1, - "tables": { - "users": { - "email": { - "indexes": { - "match": {} - } - } - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "email"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert_eq!( - column.indexes[0].index_type, - IndexType::Match { - tokenizer: Tokenizer::Standard, - token_filters: vec![], - k: 6, - m: 2048, - include_original: false - } - ); - } - - #[test] - fn can_parse_match_index_with_all_opts_set() { - let json = json!({ - "v": 1, - "tables": { - "users": { - "email": { - "indexes": { - "match": { - "tokenizer": { - 
"kind": "ngram", - "token_length": 3, - }, - "token_filters": [ - { - "kind": "downcase" - } - ], - "k": 8, - "m": 1024, - "include_original": true - } - } - } - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "email"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert_eq!( - column.indexes[0].index_type, - IndexType::Match { - tokenizer: Tokenizer::Ngram { token_length: 3 }, - token_filters: vec![TokenFilter::Downcase], - k: 8, - m: 1024, - include_original: true - } - ); - } - - #[test] - fn can_parse_ste_vec_index() { - let json = json!({ - "v": 1, - "tables": { - "users": { - "event_data": { - "indexes": { - "ste_vec": { - "prefix": "event-data" - } - } - } - } - } - }); - - let encrypt_config = parse(json); - - let ident = Identifier::new("users", "event_data"); - - let column = encrypt_config.get(&ident).expect("column exists"); - - assert_eq!( - column.indexes[0].index_type, - IndexType::SteVec { - prefix: "event-data".into(), - term_filters: vec![], - array_index_mode: ArrayIndexMode::ALL, - }, - ); - } -} diff --git a/packages/cipherstash-proxy/src/proxy/encrypt_config/manager.rs b/packages/cipherstash-proxy/src/proxy/encrypt_config/manager.rs index 5bfc3014..0b33c28e 100644 --- a/packages/cipherstash-proxy/src/proxy/encrypt_config/manager.rs +++ b/packages/cipherstash-proxy/src/proxy/encrypt_config/manager.rs @@ -8,13 +8,12 @@ use crate::{ use arc_swap::ArcSwap; use cipherstash_client::eql; use cipherstash_client::schema::ColumnConfig; +use cipherstash_config::CanonicalEncryptionConfig; use serde_json::Value; use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::{task::JoinHandle, time}; use tracing::{debug, error, info, warn}; -use super::config::ColumnEncryptionConfig; - /// /// Column configuration keyed by table name and column name /// - key: `{table_name}.{column_name}` @@ -103,6 +102,12 @@ async fn init_reloader(config: DatabaseConfig) -> Result { + error!( + 
msg = "Invalid Encrypt configuration in database", + error = inner.to_string() + ); + } _ => { error!( msg = "Error loading Encrypt configuration", @@ -213,8 +218,8 @@ pub async fn load_encrypt_config(config: &DatabaseConfig) -> Result bool { let msg = e.to_string(); msg.contains("eql_v2_configuration") && msg.contains("does not exist") } + +fn canonical_to_map(canonical: CanonicalEncryptionConfig) -> Result { + Ok(canonical + .into_config_map()? + .into_iter() + .map(|(id, col)| (eql::Identifier::new(id.table, id.column), col)) + .collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + use cipherstash_client::eql::Identifier; + use cipherstash_config::column::{ArrayIndexMode, IndexType, TokenFilter, Tokenizer}; + use cipherstash_config::ColumnType; + use serde_json::json; + + fn parse(json: serde_json::Value) -> EncryptConfigMap { + let config: CanonicalEncryptionConfig = serde_json::from_value(json).unwrap(); + canonical_to_map(config).unwrap() + } + + #[test] + fn column_with_empty_options_gets_defaults() { + let json = json!({ + "v": 1, + "tables": { "users": { "email": {} } } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "email")) + .unwrap(); + + assert_eq!(column.cast_type, ColumnType::Text); + assert!(column.indexes.is_empty()); + } + + #[test] + fn can_parse_column_with_cast_as() { + let json = json!({ + "v": 1, + "tables": { + "users": { "favourite_int": { "cast_as": "int" } } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "favourite_int")) + .unwrap(); + + assert_eq!(column.cast_type, ColumnType::Int); + assert_eq!(column.name, "favourite_int"); + assert!(column.indexes.is_empty()); + } + + #[test] + fn cast_as_real_maps_to_float() { + let json = json!({ + "v": 1, + "tables": { + "users": { "rating": { "cast_as": "real" } } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + 
.get(&Identifier::new("users", "rating")) + .unwrap(); + + assert_eq!(column.cast_type, ColumnType::Float); + } + + #[test] + fn cast_as_double_maps_to_float() { + let json = json!({ + "v": 1, + "tables": { + "users": { "rating": { "cast_as": "double" } } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "rating")) + .unwrap(); + + assert_eq!(column.cast_type, ColumnType::Float); + } + + #[test] + fn can_parse_empty_indexes() { + let json = json!({ + "v": 1, + "tables": { + "users": { "email": { "indexes": {} } } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "email")) + .unwrap(); + + assert!(column.indexes.is_empty()); + } + + #[test] + fn can_parse_ore_index() { + let json = json!({ + "v": 1, + "tables": { + "users": { "email": { "indexes": { "ore": {} } } } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "email")) + .unwrap(); + + assert_eq!(column.indexes[0].index_type, IndexType::Ore); + } + + #[test] + fn can_parse_unique_index_with_defaults() { + let json = json!({ + "v": 1, + "tables": { + "users": { "email": { "indexes": { "unique": {} } } } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "email")) + .unwrap(); + + assert_eq!( + column.indexes[0].index_type, + IndexType::Unique { + token_filters: vec![] + } + ); + } + + #[test] + fn can_parse_unique_index_with_token_filter() { + let json = json!({ + "v": 1, + "tables": { + "users": { + "email": { + "indexes": { + "unique": { + "token_filters": [{ "kind": "downcase" }] + } + } + } + } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "email")) + .unwrap(); + + assert_eq!( + column.indexes[0].index_type, + IndexType::Unique { + token_filters: vec![TokenFilter::Downcase] + } + ); + } + + #[test] + fn 
can_parse_match_index_with_defaults() { + let json = json!({ + "v": 1, + "tables": { + "users": { "email": { "indexes": { "match": {} } } } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "email")) + .unwrap(); + + assert_eq!( + column.indexes[0].index_type, + IndexType::Match { + tokenizer: Tokenizer::Standard, + token_filters: vec![], + k: 6, + m: 2048, + include_original: false, + } + ); + } + + #[test] + fn can_parse_match_index_with_all_opts_set() { + let json = json!({ + "v": 1, + "tables": { + "users": { + "email": { + "indexes": { + "match": { + "tokenizer": { "kind": "ngram", "token_length": 3 }, + "token_filters": [{ "kind": "downcase" }], + "k": 8, + "m": 1024, + "include_original": true + } + } + } + } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "email")) + .unwrap(); + + assert_eq!( + column.indexes[0].index_type, + IndexType::Match { + tokenizer: Tokenizer::Ngram { token_length: 3 }, + token_filters: vec![TokenFilter::Downcase], + k: 8, + m: 1024, + include_original: true, + } + ); + } + + #[test] + fn can_parse_ste_vec_index() { + let json = json!({ + "v": 1, + "tables": { + "users": { + "event_data": { + "cast_as": "jsonb", + "indexes": { "ste_vec": { "prefix": "event-data" } } + } + } + } + }); + + let encrypt_config = parse(json); + let column = encrypt_config + .get(&Identifier::new("users", "event_data")) + .unwrap(); + + assert_eq!( + column.indexes[0].index_type, + IndexType::SteVec { + prefix: "event-data".into(), + term_filters: vec![], + array_index_mode: ArrayIndexMode::ALL, + }, + ); + } + + #[test] + fn config_map_preserves_table_and_column_names() { + let json = json!({ + "v": 1, + "tables": { + "my_schema.users": { + "email_address": { + "cast_as": "text", + "indexes": { "unique": {} } + } + } + } + }); + + let config = parse(json); + let column = config + .get(&Identifier::new("my_schema.users", 
"email_address")) + .unwrap(); + assert_eq!(column.name, "email_address"); + assert_eq!(column.cast_type, ColumnType::Text); + } + + #[test] + fn config_map_handles_multiple_tables() { + let json = json!({ + "v": 1, + "tables": { + "users": { "email": { "cast_as": "text" } }, + "orders": { "total": { "cast_as": "int" } } + } + }); + + let config = parse(json); + + assert_eq!(config.len(), 2); + assert_eq!( + config + .get(&Identifier::new("users", "email")) + .unwrap() + .cast_type, + ColumnType::Text + ); + assert_eq!( + config + .get(&Identifier::new("orders", "total")) + .unwrap() + .cast_type, + ColumnType::Int + ); + } + + #[test] + fn invalid_config_returns_error() { + let json = json!({ + "v": 1, + "tables": { + "users": { + "email": { + "cast_as": "text", + "indexes": { "ste_vec": { "prefix": "test" } } + } + } + } + }); + + let config: CanonicalEncryptionConfig = serde_json::from_value(json).unwrap(); + assert!(canonical_to_map(config).is_err()); + } + + #[test] + fn real_eql_config_produces_correct_encrypt_config() { + let json = json!({ + "v": 1, + "tables": { + "encrypted": { + "encrypted_text": { + "cast_as": "text", + "indexes": { "unique": {}, "match": {}, "ore": {} } + }, + "encrypted_bool": { + "cast_as": "boolean", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_int2": { + "cast_as": "small_int", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_int4": { + "cast_as": "int", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_int8": { + "cast_as": "big_int", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_float8": { + "cast_as": "double", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_date": { + "cast_as": "date", + "indexes": { "unique": {}, "ore": {} } + }, + "encrypted_jsonb": { + "cast_as": "jsonb", + "indexes": { + "ste_vec": { "prefix": "encrypted/encrypted_jsonb" } + } + }, + "encrypted_jsonb_filtered": { + "cast_as": "jsonb", + "indexes": { + "ste_vec": { + "prefix": 
"encrypted/encrypted_jsonb_filtered", + "term_filters": [{ "kind": "downcase" }] + } + } + } + } + } + }); + + let config = parse(json); + + assert_eq!(config.len(), 9); + + assert_eq!( + config + .get(&Identifier::new("encrypted", "encrypted_float8")) + .unwrap() + .cast_type, + ColumnType::Float + ); + assert_eq!( + config + .get(&Identifier::new("encrypted", "encrypted_jsonb")) + .unwrap() + .cast_type, + ColumnType::Json + ); + assert_eq!( + config + .get(&Identifier::new("encrypted", "encrypted_text")) + .unwrap() + .indexes + .len(), + 3 + ); + assert_eq!( + config + .get(&Identifier::new("encrypted", "encrypted_bool")) + .unwrap() + .indexes + .len(), + 2 + ); + assert_eq!( + config + .get(&Identifier::new("encrypted", "encrypted_jsonb_filtered")) + .unwrap() + .indexes + .len(), + 1 + ); + } + + #[test] + fn malformed_json_returns_parse_error() { + let json = json!({ + "v": 1, + "tables": "not a map" + }); + + let result = serde_json::from_value::(json); + assert!(result.is_err()); + } +} diff --git a/packages/cipherstash-proxy/src/proxy/encrypt_config/mod.rs b/packages/cipherstash-proxy/src/proxy/encrypt_config/mod.rs index 57e5563e..de29826e 100644 --- a/packages/cipherstash-proxy/src/proxy/encrypt_config/mod.rs +++ b/packages/cipherstash-proxy/src/proxy/encrypt_config/mod.rs @@ -1,4 +1,3 @@ -mod config; mod manager; pub use manager::{EncryptConfig, EncryptConfigManager}; diff --git a/packages/eql-mapper/src/inference/infer_type_impls/expr.rs b/packages/eql-mapper/src/inference/infer_type_impls/expr.rs index a90f9bae..d2658b99 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/expr.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/expr.rs @@ -1,39 +1,36 @@ use crate::{ get_sql_binop_rule, inference::{unifier::Type, InferType, TypeError}, - IdentCase, TypeInferencer, + EqlTrait, IdentCase, TypeInferencer, }; use eql_mapper_macros::trace_infer; use sqltk::parser::ast::{AccessExpr, Array, Expr, Ident, Subscript}; #[trace_infer] 
impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { - fn infer_exit(&mut self, return_val: &'ast Expr) -> Result<(), TypeError> { - match return_val { + fn infer_exit(&mut self, expr_val: &'ast Expr) -> Result<(), TypeError> { + match expr_val { // Resolve an identifier using the scope, except if it happens to to be the DEFAULT keyword // in which case we resolve it to a fresh type variable. Expr::Identifier(ident) => { // sqltk_parser treats the `DEFAULT` keyword in expression position as an identifier. if IdentCase(ident) == IdentCase(&Ident::new("default")) { - self.unify_node_with_type(return_val, self.fresh_tvar())?; + self.unify_node_with_type(expr_val, self.fresh_tvar())?; } else { - self.unify_node_with_type(return_val, self.resolve_ident(ident)?)?; + self.unify_node_with_type(expr_val, self.resolve_ident(ident)?)?; }; } Expr::CompoundIdentifier(idents) => { - self.unify_node_with_type(return_val, self.resolve_compound_ident(idents)?)?; + self.unify_node_with_type(expr_val, self.resolve_compound_ident(idents)?)?; } Expr::Wildcard(_) => { - self.unify_node_with_type(return_val, self.resolve_wildcard()?)?; + self.unify_node_with_type(expr_val, self.resolve_wildcard()?)?; } Expr::QualifiedWildcard(object_name, _) => { - self.unify_node_with_type( - return_val, - self.resolve_qualified_wildcard(object_name)?, - )?; + self.unify_node_with_type(expr_val, self.resolve_qualified_wildcard(object_name)?)?; } Expr::JsonAccess { .. 
} => { @@ -51,13 +48,19 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { | Expr::IsUnknown(expr) | Expr::IsNotUnknown(expr) => { self.unify_node_with_type( - return_val, + expr_val, self.unify(self.get_node_type(&**expr), Type::native())?, )?; } Expr::IsDistinctFrom(a, b) | Expr::IsNotDistinctFrom(a, b) => { - self.unify_node_with_type(return_val, Type::native())?; + let ty = self + .unifier + .borrow_mut() + .fresh_bounded_tvar(EqlTrait::Eq.into()); + self.unify_node_with_type(&**a, ty.clone())?; + self.unify_node_with_type(&**b, ty.clone())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_nodes(&**a, &**b)?; } @@ -66,7 +69,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { list, negated: _, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_node_with_type( &**expr, list.iter().try_fold(self.get_node_type(&**expr), |a, b| { @@ -80,7 +83,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { subquery, negated: _, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; let ty = Type::projection(&[(self.get_node_type(&**expr), None)]); self.unify_node_with_type(&**subquery, ty)?; } @@ -95,12 +98,18 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { low, high, } => { - self.unify_node_with_type(return_val, Type::native())?; - self.unify_node_with_type(&**high, self.unify_nodes(&**expr, &**low)?)?; + self.unify_node_with_type(expr_val, Type::native())?; + let ty = self + .unifier + .borrow_mut() + .fresh_bounded_tvar(EqlTrait::Ord.into()); + self.unify_node_with_type(&**expr, ty.clone())?; + self.unify_node_with_type(&**low, ty.clone())?; + self.unify_node_with_type(&**high, ty.clone())?; } Expr::BinaryOp { left, op, right } => { - get_sql_binop_rule(op).apply_constraints(self, left, right, return_val)?; + get_sql_binop_rule(op).apply_constraints(self, 
left, right, expr_val)?; } //customer_name LIKE 'A%'; @@ -118,7 +127,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { escape_char: _, any: false, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_nodes(&**expr, &**pattern)?; } @@ -134,7 +143,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { pattern, escape_char: _, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_nodes_with_type(&**expr, &**pattern, Type::native())?; } @@ -153,7 +162,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { compare_op: _, right, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_nodes(&**left, &**right)?; } @@ -162,14 +171,14 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { | Expr::UnaryOp { expr, .. } | Expr::Convert { expr, .. } | Expr::Cast { expr, .. 
} => { - self.unify_nodes_with_type(return_val, &**expr, Type::native())?; + self.unify_nodes_with_type(expr_val, &**expr, Type::native())?; } Expr::AtTimeZone { timestamp, time_zone, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_node_with_type(&**timestamp, Type::native())?; self.unify_node_with_type(&**time_zone, Type::native())?; } @@ -179,12 +188,12 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { syntax: _, expr, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_node_with_type(&**expr, Type::native())?; } Expr::Position { expr, r#in } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_nodes_with_type(&**expr, &**r#in, Type::native())?; } @@ -195,7 +204,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { special: _, shorthand: _, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_node_with_type(&**expr, Type::native())?; if let Some(expr) = substring_from { self.unify_node_with_type(&**expr, Type::native())?; @@ -213,7 +222,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { trim_what, trim_characters, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_node_with_type(&**expr, Type::native())?; if let Some(trim_where) = trim_where { self.unify_node_with_type(trim_where, Type::native())?; @@ -232,7 +241,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { overlay_from, overlay_for, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_node_with_type(&**expr, Type::native())?; self.unify_node_with_type(&**overlay_what, 
Type::native())?; self.unify_node_with_type(&**overlay_from, Type::native())?; @@ -242,29 +251,29 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { } Expr::Collate { expr, collation: _ } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_node_with_type(&**expr, Type::native())?; } // The current `Expr` shares the same type hole as the sub-expression Expr::Nested(expr) => { - self.unify_nodes(return_val, &**expr)?; + self.unify_nodes(expr_val, &**expr)?; } Expr::Value(value) => { - self.unify_nodes(return_val, value)?; + self.unify_nodes(expr_val, value)?; } Expr::TypedString { data_type: _, value: _, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; } // The return type of this function and the return type of this expression must be the same type. Expr::Function(function) => { - self.unify_node_with_type(return_val, self.get_node_type(function))?; + self.unify_node_with_type(expr_val, self.get_node_type(function))?; } // When operand is Some(operand), all conditions must be of the same type as the operand and much support equality @@ -282,7 +291,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { Some(operand) => { for cond_when in conditions { self.unify_nodes_with_type( - return_val, + expr_val, &**operand, self.unify_node_with_type(&cond_when.condition, self.fresh_tvar())?, )?; @@ -303,18 +312,18 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { self.unify_node_with_type(else_result, result_ty.clone())?; }; - self.unify_node_with_type(return_val, result_ty)?; + self.unify_node_with_type(expr_val, result_ty)?; } Expr::Exists { subquery: _, negated: _, } => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; } Expr::Subquery(subquery) => { - self.unify_nodes(return_val, &**subquery)?; + self.unify_nodes(expr_val, 
&**subquery)?; } // unsupported SQL features @@ -378,7 +387,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { } } - self.unify_node_with_type(return_val, access_ty)?; + self.unify_node_with_type(expr_val, access_ty)?; self.unify_node_with_type(&**root, root_ty)?; } @@ -386,12 +395,12 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { // Constrain all elements of the array to be the same type. let elem_ty = self.unify_all_with_type(elem, self.fresh_tvar())?; let array_ty = Type::array(elem_ty); - self.unify_node_with_type(return_val, array_ty)?; + self.unify_node_with_type(expr_val, array_ty)?; } // interval is unmapped, value is unmapped Expr::Interval(interval) => { - self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(expr_val, Type::native())?; self.unify_node_with_type(&*interval.value, Type::native())?; } diff --git a/packages/eql-mapper/src/inference/registry.rs b/packages/eql-mapper/src/inference/registry.rs index 68ad0617..aed469f0 100644 --- a/packages/eql-mapper/src/inference/registry.rs +++ b/packages/eql-mapper/src/inference/registry.rs @@ -101,7 +101,7 @@ impl<'ast> TypeRegistry<'ast> { .map(|(p, ty)| Param::try_from(*p).map(|p| (p, ty.clone().follow_tvars(unifier)))) .collect::, _>>()?; - params.sort_by(|(a, _), (b, _)| a.cmp(b)); + params.sort_by_key(|(param, _)| *param); Ok(params) } diff --git a/packages/eql-mapper/src/inference/unifier/types.rs b/packages/eql-mapper/src/inference/unifier/types.rs index a40f18ba..8bff0355 100644 --- a/packages/eql-mapper/src/inference/unifier/types.rs +++ b/packages/eql-mapper/src/inference/unifier/types.rs @@ -548,7 +548,7 @@ impl Projection { &*col.ty.clone().follow_tvars(unifier) { let resolved = projection.flatten(unifier)?; - acc.extend(resolved.0.into_iter()); + acc.extend(resolved.0); } else { let ty = col.ty.clone().follow_tvars(unifier); acc.push(ProjectionColumn { ty, alias }); diff --git a/packages/showcase/src/main.rs 
b/packages/showcase/src/main.rs index 456fa1be..ee14785e 100644 --- a/packages/showcase/src/main.rs +++ b/packages/showcase/src/main.rs @@ -404,33 +404,33 @@ async fn test_complex_nested_queries() -> Result<(), Box> ); // Test 2: Aggregation with JSONB extraction - println!("📝 Test 2: Calculate max risk scores by insurance provider"); - let sql = r#" - SELECT jsonb_path_query_first(p.pii, '$.insurance.provider') as provider, - MAX(jsonb_path_query_first(p.pii, '$.medical_history.risk_factors.cardiovascular')) as max_cv_risk, - COUNT(*) as patient_count - FROM patients AS p - WHERE jsonb_path_exists(p.pii, '$.medical_history.risk_factors.cardiovascular') - GROUP BY jsonb_path_query_first(p.pii, '$.insurance.provider') - ORDER BY MAX(jsonb_path_query_first(p.pii, '$.medical_history.risk_factors.cardiovascular')) DESC - "#; - let rows = client.query(sql, &[]).await?; - println!( - "✅ Calculated risk scores for {} insurance providers", - rows.len() - ); - - for row in &rows { - let provider: Option = row.get("provider"); - let provider: Option<&str> = provider.as_ref().and_then(|v| v.as_str()); - - let avg_risk: Option = row.get("max_cv_risk"); - let avg_risk: Option = avg_risk.as_ref().and_then(|v| v.as_i64()); - - let count: Option = row.get("patient_count"); - - println!(" {provider:?}: Avg CV Risk = {avg_risk:?}, Patients = {count:?}"); - } + println!("📝 Test 2: Calculate max risk scores by insurance provider [DISABLED until EQL reinstates GROUP BY support for JSONB terms]"); + // let sql = r#" + // SELECT jsonb_path_query_first(p.pii, '$.insurance.provider') as provider, + // MAX(jsonb_path_query_first(p.pii, '$.medical_history.risk_factors.cardiovascular')) as max_cv_risk, + // COUNT(*) as patient_count + // FROM patients AS p + // WHERE jsonb_path_exists(p.pii, '$.medical_history.risk_factors.cardiovascular') + // GROUP BY jsonb_path_query_first(p.pii, '$.insurance.provider') + // ORDER BY MAX(jsonb_path_query_first(p.pii, 
'$.medical_history.risk_factors.cardiovascular')) DESC + // "#; + // let rows = client.query(sql, &[]).await?; + // println!( + // "✅ Calculated risk scores for {} insurance providers", + // rows.len() + // ); + + // for row in &rows { + // let provider: Option = row.get("provider"); + // let provider: Option<&str> = provider.as_ref().and_then(|v| v.as_str()); + + // let avg_risk: Option = row.get("max_cv_risk"); + // let avg_risk: Option = avg_risk.as_ref().and_then(|v| v.as_i64()); + + // let count: Option = row.get("patient_count"); + + // println!(" {provider:?}: Avg CV Risk = {avg_risk:?}, Patients = {count:?}"); + // } // Test 3: Complex filtering with multiple JSONB conditions println!("📝 Test 3: Find patients with allergies AND high deductibles");