Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions timeplus/base/wire_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "../exceptions.h"

#include <algorithm>
#include <stdexcept>

namespace {
Expand Down Expand Up @@ -66,6 +67,79 @@ bool WireFormat::ReadVarint64(InputStream& input, uint64_t* value) {
return false;
}

inline const char* find_quoted_chars(const char* start, const char* end)
{
static constexpr char quoted_chars[] = {'\0', '\b', '\t', '\n', '\'', '\\'};
const auto* first = std::find_first_of(start, end, std::begin(quoted_chars), std::end(quoted_chars));

return (first == end) ? nullptr : first;
}

void WireFormat::WriteQuotedString(OutputStream& output, std::string_view value) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we could have some unit test to cover this function, it will be awesome

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auto size = value.size();
const char* start = value.data();
const char* end = start + size;
const char* quoted_char = find_quoted_chars(start, end);
if (quoted_char == nullptr) {
WriteVarint64(output, size + 2);
WriteAll(output, "'", 1);
WriteAll(output, start, size);
WriteAll(output, "'", 1);
return;
}

// calculate quoted chars count
int quoted_count = 1;
const char* next_quoted_char = quoted_char + 1;
while ((next_quoted_char = find_quoted_chars(next_quoted_char, end))) {
quoted_count++;
next_quoted_char++;
}
WriteVarint64(output, size + 2 + 3 * quoted_count); // length

WriteAll(output, "'", 1);

do {
auto write_size = quoted_char - start;
WriteAll(output, start, write_size);
WriteAll(output, "\\", 1);
char c = quoted_char[0];
switch (c) {
case '\0':
WriteAll(output, "x00", 3);
break;
case '\b':
WriteAll(output, "x08", 3);
break;
case '\t':
WriteAll(output, R"(\\t)", 3);
break;
case '\n':
WriteAll(output, R"(\\n)", 3);
break;
case '\'':
WriteAll(output, "x27", 3);
break;
case '\\':
WriteAll(output, R"(\\\)", 3);
break;
default:
break;
}
start = quoted_char + 1;
quoted_char = find_quoted_chars(start, end);
} while (quoted_char);

WriteAll(output, start, end - start);
WriteAll(output, "'", 1);
}

void WireFormat::WriteParamNullRepresentation(OutputStream& output) {
const std::string NULL_REPRESENTATION(R"('\\N')");
WriteVarint64(output, NULL_REPRESENTATION.size());
WriteAll(output, NULL_REPRESENTATION.data(), NULL_REPRESENTATION.size());
}

void WireFormat::WriteVarint64(OutputStream& output, uint64_t value) {
uint8_t bytes[MAX_VARINT_BYTES];
int size = 0;
Expand Down
2 changes: 2 additions & 0 deletions timeplus/base/wire_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class WireFormat {
static void WriteFixed(OutputStream& output, const T& value);
static void WriteBytes(OutputStream& output, const void* buf, size_t len);
static void WriteString(OutputStream& output, std::string_view value);
static void WriteQuotedString(OutputStream& output, std::string_view value);
static void WriteParamNullRepresentation(OutputStream& output);
static void WriteUInt64(OutputStream& output, const uint64_t value);
static void WriteVarint64(OutputStream& output, uint64_t value);

Expand Down
34 changes: 33 additions & 1 deletion timeplus/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@
#define DBMS_MIN_REVISION_WITH_DISTRIBUTED_DEPTH 54448
#define DBMS_MIN_REVISION_WITH_INITIAL_QUERY_START_TIME 54449
#define DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS 54451
#define DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM 54458
#define DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS 54459

#define DMBS_PROTOCOL_REVISION DBMS_MIN_REVISION_WITH_INCREMENTAL_PROFILE_EVENTS
#define DMBS_PROTOCOL_REVISION DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS

namespace timeplus {

Expand Down Expand Up @@ -176,6 +178,8 @@ class Client::Impl {

bool SendHello();

bool SendAddendum();

bool ReadBlock(InputStream& input, Block* block);

bool ReceiveHello();
Expand Down Expand Up @@ -454,6 +458,9 @@ bool Client::Impl::Handshake() {
if (!ReceiveHello()) {
return false;
}
if (!SendAddendum()) {
return false;
}
return true;
}

Expand Down Expand Up @@ -845,6 +852,19 @@ void Client::Impl::SendQuery(const Query& query) {
WireFormat::WriteUInt64(*output_, compression_);
WireFormat::WriteString(*output_, query.GetText());

// Send params after query text
if (server_info_.revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS) {
for (const auto& [name, value] : query.GetParams()) {
WireFormat::WriteString(*output_, name);
const uint64_t Custom = 2;
WireFormat::WriteVarint64(*output_, Custom);
if (value)
WireFormat::WriteQuotedString(*output_, *value);
else
WireFormat::WriteParamNullRepresentation(*output_);
}
WireFormat::WriteString(*output_, std::string());
}

// Send empty block as marker of
// end of data
Expand Down Expand Up @@ -924,6 +944,18 @@ bool Client::Impl::SendHello() {
return true;
}

bool Client::Impl::SendAddendum() {
if (server_info_.revision < DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM ||
DMBS_PROTOCOL_REVISION < DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) {
return true;
}

WireFormat::WriteString(*output_, std::string());
output_->Flush();

return true;
}

bool Client::Impl::ReceiveHello() {
uint64_t packet_type = 0;

Expand Down
15 changes: 15 additions & 0 deletions timeplus/query.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ struct QuerySettingsField {
};

using QuerySettings = std::unordered_map<std::string, QuerySettingsField>;
using QueryParamValue = std::optional<std::string>;
using QueryParams = std::unordered_map<std::string, QueryParamValue>;

struct Profile {
uint64_t rows = 0;
Expand Down Expand Up @@ -114,6 +116,18 @@ class Query : public QueryEvents {
return *this;
}

inline const QueryParams& GetParams() const { return query_params_; }

inline Query& SetParams(QueryParams query_params) {
query_params_ = std::move(query_params);
return *this;
}

inline Query& SetParam(const std::string& name, const QueryParamValue& value) {
query_params_[name] = value;
return *this;
}

inline const std::optional<open_telemetry::TracingContext>& GetTracingContext() const {
return tracing_context_;
}
Expand Down Expand Up @@ -218,6 +232,7 @@ class Query : public QueryEvents {
const std::string query_id_;
std::optional<open_telemetry::TracingContext> tracing_context_;
QuerySettings query_settings_;
QueryParams query_params_;
ExceptionCallback exception_cb_;
ProgressCallback progress_cb_;
SelectCallback select_cb_;
Expand Down
31 changes: 31 additions & 0 deletions ut/stream_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,27 @@
#include <timeplus/base/input.h>

#include <gtest/gtest.h>
#include <string>

using namespace timeplus;

namespace {
std::string roundtripQuotedString(const std::string& value) {
Buffer buf;
{
BufferOutput output(&buf);
WireFormat::WriteQuotedString(output, value);
output.Flush();
}
ArrayInput input(buf.data(), buf.size());
std::string result;
if (!WireFormat::ReadString(input, &result)) {
return {};
}
return result;
}
}

TEST(CodedStreamCase, Varint64) {
Buffer buf;

Expand All @@ -22,3 +40,16 @@ TEST(CodedStreamCase, Varint64) {
ASSERT_EQ(value, 18446744071965638648ULL);
}
}

TEST(CodedStreamCase, QuotedStringPlain) {
ASSERT_EQ(roundtripQuotedString("hello"), "'hello'");
}

TEST(CodedStreamCase, QuotedStringSingleQuote) {
ASSERT_EQ(roundtripQuotedString("a'b"), "'a\\x27b'");
}

TEST(CodedStreamCase, QuotedStringNullByte) {
const std::string value("a\0b", 3);
ASSERT_EQ(roundtripQuotedString(value), "'a\\x00b'");
}
Loading