diff --git a/timeplus/base/wire_format.cpp b/timeplus/base/wire_format.cpp index b042fcf..1e38cf3 100644 --- a/timeplus/base/wire_format.cpp +++ b/timeplus/base/wire_format.cpp @@ -5,6 +5,7 @@ #include "../exceptions.h" +#include #include namespace { @@ -66,6 +67,79 @@ bool WireFormat::ReadVarint64(InputStream& input, uint64_t* value) { return false; } +inline const char* find_quoted_chars(const char* start, const char* end) +{ + static constexpr char quoted_chars[] = {'\0', '\b', '\t', '\n', '\'', '\\'}; + const auto* first = std::find_first_of(start, end, std::begin(quoted_chars), std::end(quoted_chars)); + + return (first == end) ? nullptr : first; +} + +void WireFormat::WriteQuotedString(OutputStream& output, std::string_view value) { + auto size = value.size(); + const char* start = value.data(); + const char* end = start + size; + const char* quoted_char = find_quoted_chars(start, end); + if (quoted_char == nullptr) { + WriteVarint64(output, size + 2); + WriteAll(output, "'", 1); + WriteAll(output, start, size); + WriteAll(output, "'", 1); + return; + } + + // calculate quoted chars count + int quoted_count = 1; + const char* next_quoted_char = quoted_char + 1; + while ((next_quoted_char = find_quoted_chars(next_quoted_char, end))) { + quoted_count++; + next_quoted_char++; + } + WriteVarint64(output, size + 2 + 3 * quoted_count); // length + + WriteAll(output, "'", 1); + + do { + auto write_size = quoted_char - start; + WriteAll(output, start, write_size); + WriteAll(output, "\\", 1); + char c = quoted_char[0]; + switch (c) { + case '\0': + WriteAll(output, "x00", 3); + break; + case '\b': + WriteAll(output, "x08", 3); + break; + case '\t': + WriteAll(output, R"(\\t)", 3); + break; + case '\n': + WriteAll(output, R"(\\n)", 3); + break; + case '\'': + WriteAll(output, "x27", 3); + break; + case '\\': + WriteAll(output, R"(\\\)", 3); + break; + default: + break; + } + start = quoted_char + 1; + quoted_char = find_quoted_chars(start, end); + } while (quoted_char); + + WriteAll(output, start, end - start); + WriteAll(output, "'", 1); +} + +void WireFormat::WriteParamNullRepresentation(OutputStream& output) { + const std::string NULL_REPRESENTATION(R"('\\N')"); + WriteVarint64(output, NULL_REPRESENTATION.size()); + WriteAll(output, NULL_REPRESENTATION.data(), NULL_REPRESENTATION.size()); +} + void WireFormat::WriteVarint64(OutputStream& output, uint64_t value) { uint8_t bytes[MAX_VARINT_BYTES]; int size = 0; diff --git a/timeplus/base/wire_format.h b/timeplus/base/wire_format.h index d88ff12..0029e07 100644 --- a/timeplus/base/wire_format.h +++ b/timeplus/base/wire_format.h @@ -22,6 +22,8 @@ class WireFormat { static void WriteFixed(OutputStream& output, const T& value); static void WriteBytes(OutputStream& output, const void* buf, size_t len); static void WriteString(OutputStream& output, std::string_view value); + static void WriteQuotedString(OutputStream& output, std::string_view value); + static void WriteParamNullRepresentation(OutputStream& output); static void WriteUInt64(OutputStream& output, const uint64_t value); static void WriteVarint64(OutputStream& output, uint64_t value); diff --git a/timeplus/client.cpp b/timeplus/client.cpp index 188143e..4522502 100644 --- a/timeplus/client.cpp +++ b/timeplus/client.cpp @@ -39,8 +39,10 @@ #define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448 #define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449 #define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451 +#define DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM 54458 +#define DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS 54459 -#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS +#define DMBS_PROTOCOL_REVISION DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS namespace timeplus { @@ -176,6 +178,8 @@ class Client::Impl { bool SendHello(); + bool SendAddendum(); + bool ReadBlock(InputStream& input, Block* block); bool ReceiveHello(); @@ -454,6 +458,9 @@ bool Client::Impl::Handshake() { if (!ReceiveHello()) { return false; } + if (!SendAddendum()) { + return false; + } return true; } @@ -845,6 +852,19 @@ void Client::Impl::SendQuery(const Query& query) { WireFormat::WriteUInt64(*output_, compression_); WireFormat::WriteString(*output_, query.GetText()); + // Send params after query text + if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS) { + for (const auto& [name, value] : query.GetParams()) { + WireFormat::WriteString(*output_, name); + const uint64_t Custom = 2; + WireFormat::WriteVarint64(*output_, Custom); + if (value) + WireFormat::WriteQuotedString(*output_, *value); + else + WireFormat::WriteParamNullRepresentation(*output_); + } + WireFormat::WriteString(*output_, std::string()); + } // Send empty block as marker of // end of data @@ -924,6 +944,18 @@ bool Client::Impl::SendHello() { return true; } +bool Client::Impl::SendAddendum() { + if (server_info_.revision < DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM || + DMBS_PROTOCOL_REVISION < DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) { + return true; + } + + WireFormat::WriteString(*output_, std::string()); + output_->Flush(); + + return true; +} + bool Client::Impl::ReceiveHello() { uint64_t packet_type = 0; diff --git a/timeplus/query.h b/timeplus/query.h index 3f76779..6392ac3 100644 --- a/timeplus/query.h +++ b/timeplus/query.h @@ -26,6 +26,8 @@ struct QuerySettingsField { }; using QuerySettings = std::unordered_map; +using QueryParamValue = std::optional; +using QueryParams = std::unordered_map; struct Profile { uint64_t rows = 0; @@ -114,6 +116,18 @@ class Query : public QueryEvents { return *this; } + inline const QueryParams& GetParams() const { return query_params_; } + + inline Query& SetParams(QueryParams query_params) { + query_params_ = std::move(query_params); + return *this; + } + + inline Query& SetParam(const std::string& name, const QueryParamValue& value) { + query_params_[name] = value; + return *this; + } + inline const std::optional& GetTracingContext() const { return tracing_context_; } @@ -218,6 +232,7 @@ class Query : public QueryEvents { const std::string query_id_; std::optional tracing_context_; QuerySettings query_settings_; + QueryParams query_params_; ExceptionCallback exception_cb_; ProgressCallback progress_cb_; SelectCallback select_cb_; diff --git a/ut/stream_ut.cpp b/ut/stream_ut.cpp index 51cd48d..36fe04e 100644 --- a/ut/stream_ut.cpp +++ b/ut/stream_ut.cpp @@ -3,9 +3,27 @@ #include #include +#include using namespace timeplus; +namespace { + std::string roundtripQuotedString(const std::string& value) { + Buffer buf; + { + BufferOutput output(&buf); + WireFormat::WriteQuotedString(output, value); + output.Flush(); + } + ArrayInput input(buf.data(), buf.size()); + std::string result; + if (!WireFormat::ReadString(input, &result)) { + return {}; + } + return result; + } +} + TEST(CodedStreamCase, Varint64) { Buffer buf; @@ -22,3 +40,16 @@ TEST(CodedStreamCase, Varint64) { ASSERT_EQ(value, 18446744071965638648ULL); } } + +TEST(CodedStreamCase, QuotedStringPlain) { + ASSERT_EQ(roundtripQuotedString("hello"), "'hello'"); +} + +TEST(CodedStreamCase, QuotedStringSingleQuote) { + ASSERT_EQ(roundtripQuotedString("a'b"), "'a\\x27b'"); +} + +TEST(CodedStreamCase, QuotedStringNullByte) { + const std::string value("a\0b", 3); + ASSERT_EQ(roundtripQuotedString(value), "'a\\x00b'"); +}