diff --git a/topic/src/main/java/tech/ydb/topic/impl/TopicRetryableStream.java b/topic/src/main/java/tech/ydb/topic/impl/TopicRetryableStream.java index 18c3acf57..3cdb97368 100644 --- a/topic/src/main/java/tech/ydb/topic/impl/TopicRetryableStream.java +++ b/topic/src/main/java/tech/ydb/topic/impl/TopicRetryableStream.java @@ -33,7 +33,6 @@ public TopicRetryableStream(Logger logger, String debugId, RetryConfig config, S } protected abstract TopicStream createNewStream(String debugId); - protected abstract W getInitRequest(); protected abstract void onNext(R message); @@ -53,7 +52,7 @@ public void start() { return; } - stream.start(getInitRequest(), this::onNext).whenComplete((status, th) -> { + stream.start(this::onNext).whenComplete((status, th) -> { realStream.compareAndSet(stream, null); if (status != null) { onStreamStop(status, retryConfig.getStatusRetryPolicy(status)); diff --git a/topic/src/main/java/tech/ydb/topic/impl/TopicStream.java b/topic/src/main/java/tech/ydb/topic/impl/TopicStream.java index d435693ad..75236073b 100644 --- a/topic/src/main/java/tech/ydb/topic/impl/TopicStream.java +++ b/topic/src/main/java/tech/ydb/topic/impl/TopicStream.java @@ -8,8 +8,8 @@ import tech.ydb.core.Status; public interface TopicStream { - CompletableFuture start(W initReq, Consumer messageHandler); - void send(W request); + CompletableFuture start(Consumer messageHandler); + void send(W request); void close(); } diff --git a/topic/src/main/java/tech/ydb/topic/impl/TopicStreamBase.java b/topic/src/main/java/tech/ydb/topic/impl/TopicStreamBase.java index 520bca53e..f091afe7f 100644 --- a/topic/src/main/java/tech/ydb/topic/impl/TopicStreamBase.java +++ b/topic/src/main/java/tech/ydb/topic/impl/TopicStreamBase.java @@ -15,13 +15,15 @@ public abstract class TopicStreamBase impl private final Logger logger; private final String debugId; private final GrpcReadWriteStream stream; + private final W initRequest; private final CompletableFuture streamStatus = new CompletableFuture<>(); private volatile String token; - public TopicStreamBase(Logger logger, String debugId, GrpcReadWriteStream stream) { + public TopicStreamBase(Logger logger, String debugId, GrpcReadWriteStream stream, W initRequest) { this.logger = logger; this.debugId = debugId; this.stream = stream; + this.initRequest = initRequest; this.token = stream.authToken(); } @@ -29,7 +31,7 @@ public TopicStreamBase(Logger logger, String debugId, GrpcReadWriteStream protected abstract Status parseMessageStatus(R message); @Override - public CompletableFuture start(W initReq, Consumer messageHandler) { + public CompletableFuture start(Consumer messageHandler) { this.logger.debug("[{}] is about to start", debugId); this.stream.start((R msg) -> { Status messageStatus = parseMessageStatus(msg); @@ -48,7 +50,7 @@ public CompletableFuture start(W initReq, Consumer messageHandler) { }); if (!streamStatus.isDone()) { - stream.sendNext(initReq); + stream.sendNext(initRequest); } return streamStatus; diff --git a/topic/src/main/java/tech/ydb/topic/impl/TopicStreamFail.java b/topic/src/main/java/tech/ydb/topic/impl/TopicStreamFail.java index 20ef759a4..3d1965d7a 100644 --- a/topic/src/main/java/tech/ydb/topic/impl/TopicStreamFail.java +++ b/topic/src/main/java/tech/ydb/topic/impl/TopicStreamFail.java @@ -21,7 +21,7 @@ public TopicStreamFail(Logger logger, String debugId, Status status) { } @Override - public CompletableFuture start(W initReq, Consumer messageHandler) { + public CompletableFuture start(Consumer messageHandler) { return CompletableFuture.completedFuture(status); } diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java index f2a5a7d9a..ee716e43b 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteSession.java @@ -1,5 +1,6 @@ package tech.ydb.topic.write.impl; import java.util.List; +import java.util.concurrent.ScheduledExecutorService; import java.util.function.BiConsumer; import org.slf4j.Logger; @@ -10,7 +11,6 @@ import tech.ydb.proto.topic.YdbTopic; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; -import tech.ydb.topic.TopicRpc; import tech.ydb.topic.impl.TopicRetryableStream; import tech.ydb.topic.impl.TopicStream; import tech.ydb.topic.settings.WriterSettings; @@ -38,10 +38,11 @@ public interface Listener { private final MessageSender sender; private final BiConsumer errorsHandler; - public WriteSession(String debugId, TopicRpc rpc, WriterSettings settings, Listener controller) { - super(logger, debugId, settings.getRetryConfig(), rpc.getScheduler()); + public WriteSession(String debugId, WriteStreamFactory factory, WriterSettings settings, + ScheduledExecutorService scheduler, Listener controller) { + super(logger, debugId, settings.getRetryConfig(), scheduler); this.listener = controller; - this.streamFactory = WriteStreamFactory.of(rpc, settings); + this.streamFactory = factory; this.sender = new MessageSender(debugId, settings.getCodec(), this::send); this.errorsHandler = settings.getErrorsHandler(); } @@ -51,11 +52,6 @@ protected Stream createNewStream(String id) { return streamFactory.createNewStream(id); } - @Override - protected FromClient getInitRequest() { - return streamFactory.initRequest(); - } - public void sendAll(List list) { for (SentMessage msg: list) { sender.sendMessage(msg); diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java index 10ccad609..3f0b1ae89 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStream.java @@ -20,8 +20,8 @@ public class WriteStream extends TopicStreamBase implements WriteSession.Stream { private static final Logger logger = LoggerFactory.getLogger(WriteStream.class); - public WriteStream(String id, GrpcReadWriteStream stream) { - super(logger, id, stream); + public WriteStream(String id, GrpcReadWriteStream stream, FromClient initRequest) { + super(logger, id, stream, initRequest); } @Override diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamDirectFactory.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamDirectFactory.java new file mode 100644 index 000000000..c3cf7e425 --- /dev/null +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamDirectFactory.java @@ -0,0 +1,166 @@ +package tech.ydb.topic.write.impl; + +import java.time.Duration; +import java.util.concurrent.CompletableFuture; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import tech.ydb.core.Issue; +import tech.ydb.core.Result; +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; +import tech.ydb.core.grpc.GrpcReadWriteStream; +import tech.ydb.core.grpc.GrpcRequestSettings; +import tech.ydb.proto.StatusCodesProtos; +import tech.ydb.proto.topic.YdbTopic; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; +import tech.ydb.topic.TopicRpc; +import tech.ydb.topic.settings.WriterSettings; + +/** + * + * @author Aleksandr Gorshenin {@literal } + */ +public class WriteStreamDirectFactory extends WriteStreamFactory { + private static final Logger logger = LoggerFactory.getLogger(WriteStreamDirectFactory.class); + + public WriteStreamDirectFactory(TopicRpc rpc, WriterSettings settings) { + super(rpc, settings); + + if (settings.getPartitionId() == null && settings.getProducerId() == null) { + throw new IllegalArgumentException("Direct writing requires PartitionId or ProducerId in WriterSettings"); + } + } + + @Override + public WriteSession.Stream createNewStream(String id) { + Long targetPartitionId = partitionId; + if (targetPartitionId == null) { + Result pid = lookupPartitionId(id); + if (!pid.isSuccess()) { + return new WriteStream.Fail(id, pid.getStatus()); + } + targetPartitionId = pid.getValue(); + } + + Result location = lookupLocation(id, targetPartitionId); + if (!location.isSuccess()) { + return new WriteStream.Fail(id, location.getStatus()); + } + + StreamWriteMessage.InitRequest.Builder req = StreamWriteMessage.InitRequest.newBuilder() + .setPath(topicPath) + .setPartitionWithGeneration(YdbTopic.PartitionWithGeneration.newBuilder() + .setPartitionId(targetPartitionId) + .setGeneration(location.getValue().getGeneration()) + .build()); + + if (producerId != null) { + req.setProducerId(producerId); + } + + FromClient init = FromClient.newBuilder().setInitRequest(req.build()).build(); + GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() + .withTraceId(id) + .disableDeadline() + .withDirectMode(true) + .withPreferredNodeID(location.getValue().getNodeId()) + .build(); + + return new WriteStream(id, rpc.writeSession(settings), init); + } + + protected Result lookupLocation(String id, long targetPartitionId) { + logger.info("[{}] describe topic {} to look up node for partition {}", id, topicPath, targetPartitionId); + Result describeTopic = rpc.describeTopic( + YdbTopic.DescribeTopicRequest.newBuilder().setIncludeLocation(true).setPath(topicPath).build(), + GrpcRequestSettings.newBuilder().withDeadline(Duration.ofMinutes(1)).build() + ).join(); + + if (!describeTopic.isSuccess()) { + logger.warn("[{}] describe topic {} failed with status {}", id, topicPath, describeTopic.getStatus()); + return Result.fail(describeTopic.getStatus()); + } + + // lookup for partition location + for (YdbTopic.DescribeTopicResult.PartitionInfo partition : describeTopic.getValue().getPartitionsList()) { + if (partition.getPartitionId() == targetPartitionId) { + if (!partition.hasPartitionLocation()) { + logger.warn("[{}] partition {} has no valid location info", id, targetPartitionId); + Issue issue = Issue.of("Partition " + targetPartitionId + " has no location", Issue.Severity.ERROR); + return Result.fail(Status.of(StatusCode.BAD_REQUEST, issue)); + } + + return Result.success(partition.getPartitionLocation()); + } + } + + logger.warn("[{}] topic {} doesn't have partition {}, direct writing failed", id, topicPath, targetPartitionId); + Issue issue = Issue.of("Cannot find partition " + targetPartitionId, Issue.Severity.ERROR); + return Result.fail(Status.of(StatusCode.BAD_REQUEST, issue)); + } + + private Result lookupPartitionId(String id) { + CompletableFuture> pidFuture = new CompletableFuture<>(); + + // create one-shot stream to detect partitionID for this producer + logger.info("[{}] create probe stream for topic {} with producer {}", id, topicPath, producerId); + GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() + .withTraceId(id + "-probe") + .withDeadline(Duration.ofMinutes(1)) + .build(); + GrpcReadWriteStream stream = rpc.writeSession(settings); + + CompletableFuture streamFuture = stream.start(resp -> { + if (resp.getStatus() != StatusCodesProtos.StatusIds.StatusCode.SUCCESS) { + Status status = Status.of(StatusCode.fromProto(resp.getStatus()), Issue.fromPb(resp.getIssuesList())); + logger.warn("[{}] probe stream to topic {} with producer {} got error {}", id, topicPath, + producerId, status); + pidFuture.complete(Result.fail(status)); + return; + } + + if (resp.hasInitResponse()) { + long pid = resp.getInitResponse().getPartitionId(); + logger.info("[{}] probe stream to topic {} with producer {} has partition {}", id, topicPath, + producerId, pid); + pidFuture.complete(Result.success(pid)); + return; + } + + logger.warn("[{}] probe stream to topic {} with producer {} got unexpected message {}", id, topicPath, + producerId, resp.getClass().getName()); + + Issue issue = Issue.of("Unexpected message from stream with producer " + producerId, Issue.Severity.ERROR); + pidFuture.complete(Result.fail(Status.of(StatusCode.BAD_REQUEST, issue))); + }); + + if (streamFuture.isDone()) { + logger.warn("[{}] probe stream to topic {} with producer {} failed with status {}", id, topicPath, + producerId, streamFuture.join()); + return Result.fail(streamFuture.join()); + } + + try { + streamFuture.whenComplete((st, th) -> { + Status status = st != null ? st : Status.of(StatusCode.CLIENT_INTERNAL_ERROR, th); + if (pidFuture.complete(Result.fail(status))) { + logger.warn("[{}] probe stream to topic {} with producer {} failed with status {}", id, topicPath, + producerId, status); + } + }); + YdbTopic.StreamWriteMessage.FromClient init = YdbTopic.StreamWriteMessage.FromClient.newBuilder() + .setInitRequest(buildInitRequest()) + .build(); + stream.sendNext(init); + return pidFuture.join(); + } finally { + if (!streamFuture.isDone()) { + stream.close(); + } + } + } +} diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamFactory.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamFactory.java index a1920a740..2cc0459d5 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamFactory.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriteStreamFactory.java @@ -1,24 +1,9 @@ package tech.ydb.topic.write.impl; -import java.time.Duration; -import java.util.concurrent.CompletableFuture; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import tech.ydb.core.Issue; -import tech.ydb.core.Result; -import tech.ydb.core.Status; -import tech.ydb.core.StatusCode; -import tech.ydb.core.grpc.GrpcReadWriteStream; -import tech.ydb.core.grpc.GrpcRequestSettings; -import tech.ydb.proto.StatusCodesProtos; -import tech.ydb.proto.topic.YdbTopic; -import tech.ydb.proto.topic.YdbTopic.DescribeTopicRequest; -import tech.ydb.proto.topic.YdbTopic.DescribeTopicResult; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; -import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; import tech.ydb.topic.TopicRpc; import tech.ydb.topic.settings.WriterSettings; @@ -27,203 +12,48 @@ * @author Aleksandr Gorshenin */ public class WriteStreamFactory { - private static final Logger logger = LoggerFactory.getLogger(WriteStreamFactory.class); - - private final String topicPath; - private final StreamWriteMessage.InitRequest initRequest; + protected final String topicPath; protected final TopicRpc rpc; + protected final String producerId; + protected final String messageGroupId; + protected final Long partitionId; - private WriteStreamFactory(TopicRpc rpc, WriterSettings settings) { + public WriteStreamFactory(TopicRpc rpc, WriterSettings settings) { this.rpc = rpc; this.topicPath = settings.getTopicPath(); - String producerId = settings.getProducerId(); - String messageGroupId = settings.getMessageGroupId(); - Long partitionId = settings.getPartitionId(); - - StreamWriteMessage.InitRequest.Builder req = StreamWriteMessage.InitRequest.newBuilder() - .setPath(topicPath); + this.producerId = settings.getProducerId(); + this.messageGroupId = settings.getMessageGroupId(); + this.partitionId = settings.getPartitionId(); - if (producerId != null) { - req.setProducerId(producerId); - } - if (messageGroupId != null) { - if (partitionId != null) { - throw new IllegalArgumentException("Both MessageGroupId and PartitionId are set in WriterSettings"); - } - req.setMessageGroupId(messageGroupId); - } else if (partitionId != null) { - req.setPartitionId(partitionId); + if (messageGroupId != null && partitionId != null) { + throw new IllegalArgumentException("Both MessageGroupId and PartitionId are set in WriterSettings"); } - - this.initRequest = req.build(); } public String getTopicPath() { return topicPath; } - public WriteSession.Stream createNewStream(String id) { - return new WriteStream(id, rpc.writeSession(id)); - } - - public YdbTopic.StreamWriteMessage.FromClient initRequest() { - return YdbTopic.StreamWriteMessage.FromClient.newBuilder() - .setInitRequest(initRequest) - .build(); - } - - protected Result lookupNodeId(String id, long partitionId) { - logger.info("[{}] describe topic {} to look up node for partition {}", id, topicPath, partitionId); - Result describeTopic = rpc.describeTopic( - DescribeTopicRequest.newBuilder().setIncludeLocation(true).setPath(topicPath).build(), - GrpcRequestSettings.newBuilder().withDeadline(Duration.ofMinutes(1)).build() - ).join(); - - if (!describeTopic.isSuccess()) { - logger.warn("[{}] describe topic {} failed with status {}", id, topicPath, describeTopic.getStatus()); - return Result.fail(describeTopic.getStatus()); - } - - // lookup for nodeID - for (DescribeTopicResult.PartitionInfo partition : describeTopic.getValue().getPartitionsList()) { - if (partition.getPartitionId() == partitionId) { - return Result.success(partition.getPartitionLocation().getNodeId()); - } - } - - logger.warn("[{}] topic {} doesn't have partition {}, direct writing failed", id, topicPath, partitionId); - Issue issue = Issue.of("Cannot find partition " + partitionId, Issue.Severity.ERROR); - return Result.fail(Status.of(StatusCode.BAD_REQUEST, issue)); - } - - protected Result lookupPartitionId(String id, String producerId) { - CompletableFuture> partitionId = new CompletableFuture<>(); - - // create one-shot stream to detect partitionID for this producer - logger.info("[{}] create probe stream for topic {} with producer {}", id, topicPath, producerId); - GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() - .withTraceId(id + "-probe") - .withDeadline(Duration.ofMinutes(1)) - .build(); - GrpcReadWriteStream stream = rpc.writeSession(settings); - - CompletableFuture streamFuture = stream.start(resp -> { - if (resp.getStatus() != StatusCodesProtos.StatusIds.StatusCode.SUCCESS) { - Status status = Status.of(StatusCode.fromProto(resp.getStatus()), Issue.fromPb(resp.getIssuesList())); - logger.warn("[{}] probe stream to topic {} with producer {} got error {}", id, topicPath, - producerId, status); - partitionId.complete(Result.fail(status)); - return; - } - - if (resp.hasInitResponse()) { - long pid = resp.getInitResponse().getPartitionId(); - logger.info("[{}] probe stream to topic {} with producer {} has partition {}", id, topicPath, - producerId, pid); - partitionId.complete(Result.success(pid)); - return; - } - - logger.warn("[{}] probe stream to topic {} with producer {} got unexpected message {}", id, topicPath, - producerId, resp.getClass().getName()); - - Issue issue = Issue.of("Unexpected message from stream with producer " + producerId, Issue.Severity.ERROR); - partitionId.complete(Result.fail(Status.of(StatusCode.BAD_REQUEST, issue))); - }); - - if (streamFuture.isDone()) { - logger.warn("[{}] probe stream to topic {} with producer {} failed with status {}", id, topicPath, - producerId, streamFuture.join()); - return Result.fail(streamFuture.join()); - } - - try { - streamFuture.whenComplete((st, th) -> { - Status status = st != null ? st : Status.of(StatusCode.CLIENT_INTERNAL_ERROR, th); - if (!partitionId.isDone()) { - logger.warn("[{}] probe stream to topic {} with producer {} failed with status {}", id, topicPath, - producerId, status); - partitionId.complete(Result.fail(status)); - } - }); - stream.sendNext(initRequest()); - return partitionId.join(); - } finally { - if (!streamFuture.isDone()) { - stream.close(); - } - } - } - - public static WriteStreamFactory of(TopicRpc rpc, WriterSettings settings) { - if (!settings.isDirectWrite()) { - return new WriteStreamFactory(rpc, settings); - } + public StreamWriteMessage.InitRequest buildInitRequest() { + StreamWriteMessage.InitRequest.Builder req = StreamWriteMessage.InitRequest.newBuilder() + .setPath(topicPath); - if (settings.getPartitionId() != null) { - return new DirectWriteByPartitionId(rpc, settings, settings.getPartitionId()); + if (producerId != null) { + req.setProducerId(producerId); } - - if (settings.getProducerId() != null) { - return new DirectWriteByProducerId(rpc, settings, settings.getProducerId()); + if (messageGroupId != null) { + req.setMessageGroupId(messageGroupId); } - - throw new IllegalArgumentException("Direct writing requires PartitionId or ProducerId in WriterSettings"); - } - - private static class DirectWriteByPartitionId extends WriteStreamFactory { - private final long partitionId; - - private DirectWriteByPartitionId(TopicRpc rpc, WriterSettings settings, long partitionId) { - super(rpc, settings); - this.partitionId = partitionId; + if (partitionId != null) { + req.setPartitionId(partitionId); } - @Override - public WriteSession.Stream createNewStream(String id) { - Result nodeId = lookupNodeId(id, partitionId); - if (!nodeId.isSuccess()) { - return new WriteStream.Fail(id, nodeId.getStatus()); - } - - GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() - .withTraceId(id) - .disableDeadline() - .withDirectMode(true) - .withPreferredNodeID(nodeId.getValue()) - .build(); - return new WriteStream(id, rpc.writeSession(settings)); - } + return req.build(); } - private static class DirectWriteByProducerId extends WriteStreamFactory { - private final String producerId; - - private DirectWriteByProducerId(TopicRpc rpc, WriterSettings settings, String producerId) { - super(rpc, settings); - this.producerId = producerId; - } - - @Override - public WriteSession.Stream createNewStream(String id) { - Result partitionId = lookupPartitionId(id, producerId); - if (!partitionId.isSuccess()) { - return new WriteStream.Fail(id, partitionId.getStatus()); - } - - Result nodeId = lookupNodeId(id, partitionId.getValue()); - if (!nodeId.isSuccess()) { - return new WriteStream.Fail(id, nodeId.getStatus()); - } - - GrpcRequestSettings settings = GrpcRequestSettings.newBuilder() - .withTraceId(id) - .disableDeadline() - .withDirectMode(true) - .withPreferredNodeID(nodeId.getValue()) - .build(); - return new WriteStream(id, rpc.writeSession(settings)); - } + public WriteSession.Stream createNewStream(String id) { + FromClient init = FromClient.newBuilder().setInitRequest(buildInitRequest()).build(); + return new WriteStream(id, rpc.writeSession(id), init); } } diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java index b209afc8a..6e8e8e2c6 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java @@ -47,12 +47,15 @@ public class WriterImpl { private Boolean isSeqNoProvided = null; - public WriterImpl(TopicRpc topicRpc, - WriterSettings settings, - Executor compressionExecutor, - @Nonnull CodecRegistry codecRegistry) { + public WriterImpl(TopicRpc topicRpc, WriterSettings settings, Executor compressionExecutor, + @Nonnull CodecRegistry codecRegistry) { + this(topicRpc, defaultFactory(topicRpc, settings), settings, compressionExecutor, codecRegistry); + } + + public WriterImpl(TopicRpc topicRpc, WriteStreamFactory factory, WriterSettings settings, + Executor compressionExecutor, @Nonnull CodecRegistry codecRegistry) { this.debugId = DebugTools.createDebugId(settings.getLogPrefix()); - this.stream = new WriteSession(debugId, topicRpc, settings, new ListenerImpl()); + this.stream = new WriteSession(debugId, factory, settings, topicRpc.getScheduler(), new ListenerImpl()); this.writeQueue = new WriterQueue(debugId, settings, codecRegistry, compressionExecutor, sendTask); logger.info("Writer with id {} created for topic \"{}\" with producerId \"{}\" and messageGroupId \"{}\"", @@ -180,4 +183,11 @@ public void run() { stream.sendAll(send); } } + + private static WriteStreamFactory defaultFactory(TopicRpc topicRpc, WriterSettings settings) { + if (settings.isDirectWrite()) { + return new WriteStreamDirectFactory(topicRpc, settings); + } + return new WriteStreamFactory(topicRpc, settings); + } } diff --git a/topic/src/test/java/tech/ydb/topic/TopicWritersIntegrationTest.java b/topic/src/test/java/tech/ydb/topic/TopicWritersIntegrationTest.java index 9b746cdb7..b739cfdc9 100644 --- a/topic/src/test/java/tech/ydb/topic/TopicWritersIntegrationTest.java +++ b/topic/src/test/java/tech/ydb/topic/TopicWritersIntegrationTest.java @@ -5,6 +5,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; @@ -21,17 +22,20 @@ import org.slf4j.LoggerFactory; import tech.ydb.core.StatusCode; +import tech.ydb.core.UnexpectedResultException; import tech.ydb.core.utils.FutureTools; import tech.ydb.test.junit4.GrpcTransportRule; import tech.ydb.topic.description.Consumer; import tech.ydb.topic.read.DeferredCommitter; import tech.ydb.topic.read.SyncReader; import tech.ydb.topic.settings.CreateTopicSettings; +import tech.ydb.topic.settings.PartitioningSettings; import tech.ydb.topic.settings.ReaderSettings; import tech.ydb.topic.settings.TopicReadSettings; import tech.ydb.topic.settings.TopicRetryConfig; import tech.ydb.topic.settings.WriterSettings; import tech.ydb.topic.write.AsyncWriter; +import tech.ydb.topic.write.InitResult; import tech.ydb.topic.write.Message; import tech.ydb.topic.write.QueueOverflowException; import tech.ydb.topic.write.SyncWriter; @@ -54,6 +58,7 @@ public class TopicWritersIntegrationTest { private final static String TEST_PRODUCER1 = "producer"; private final static String TEST_CONSUMER1 = "consumer"; + private final static int PARTITIONS_COUNT = 1; private static TopicClient client; @@ -74,6 +79,10 @@ public void initTopic() { logger.info("Create test topic {} ...", TEST_TOPIC); client.createTopic(TEST_TOPIC, CreateTopicSettings.newBuilder() .addConsumer(Consumer.newBuilder().setName(TEST_CONSUMER1).build()) + .setPartitioningSettings(PartitioningSettings.newBuilder() + .setMaxActivePartitions(PARTITIONS_COUNT) + .setMinActivePartitions(PARTITIONS_COUNT) + .build()) .build()) .join().expectSuccess("can't create a new topic"); } @@ -363,4 +372,45 @@ public void idempotentWriterTest() throws Exception { writer2.shutdown().join(); } + + @Test + public void wrongDirectWriteTest() throws Exception { + CountDownLatch closed = new CountDownLatch(1); + + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath(TEST_TOPIC) + .setDirectWrite(true) + .setPartitionId(PARTITIONS_COUNT + 1) // Invalid partition + .setRetryConfig(TopicRetryConfig.STANDARD) + .setErrorsHandler((t, u) -> { closed.countDown(); }) + .build(); + + AsyncWriter writer = client.createAsyncWriter(settings); + CompletableFuture f1 = writer.send(Message.of(new byte[] { 0x00 })); + CompletableFuture f2 = writer.init(); + + Assert.assertTrue(closed.await(5, TimeUnit.SECONDS)); + + CompletableFuture f3 = writer.shutdown(); + + Assert.assertTrue(f1.isCompletedExceptionally()); + Assert.assertTrue(f2.isCompletedExceptionally()); + Assert.assertFalse(f3.isCompletedExceptionally()); + + Exception ex1 = Assert.assertThrows(CompletionException.class, f1::join); + Exception ex2 = Assert.assertThrows(CompletionException.class, f2::join); + + Assert.assertTrue(ex1.getCause() instanceof RuntimeException); + Assert.assertTrue(ex2.getCause() instanceof UnexpectedResultException); + + String reason = "Cannot find partition " + (PARTITIONS_COUNT + 1) + " (S_ERROR)"; + Assert.assertEquals( + "Message sending was cancelled with Status{code = BAD_REQUEST(code=400010), issues = [" + reason + "]}", + ex1.getCause().getMessage() + ); + Assert.assertEquals( + "Cannot init write session, code: BAD_REQUEST, issues: [" + reason + "]", + ex2.getCause().getMessage() + ); + } } diff --git a/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java b/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java index 3ef82d174..f09eab067 100644 --- a/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java +++ b/topic/src/test/java/tech/ydb/topic/impl/TopicRetryableStreamTest.java @@ -42,14 +42,14 @@ private static class StreamHandle { StreamHandle(TopicStreamBase mocked) { this.stream = mocked; - Mockito.when(mocked.start(Mockito.any(), Mockito.any())).thenReturn(grpcFuture); + Mockito.when(mocked.start(Mockito.any())).thenReturn(grpcFuture); } StreamHandle() { Mockito.when(grpc.authToken()).thenReturn("token"); Mockito.when(grpc.start(Mockito.any())).thenReturn(grpcFuture); - stream = new TopicStreamBase(logger, "inner", grpc) { + stream = new TopicStreamBase(logger, "inner", grpc, EMPTY) { @Override protected Empty updateTokenMessage(String token) { return EMPTY; @@ -89,11 +89,6 @@ protected TopicStream createNewStream(String debugId) { return handles.get(handleIndex++).stream; } - @Override - protected Empty getInitRequest() { - return EMPTY; - } - @Override protected void onNext(Empty message) { receivedMessages.add(message); diff --git a/topic/src/test/java/tech/ydb/topic/impl/TopicStreamTest.java b/topic/src/test/java/tech/ydb/topic/impl/TopicStreamTest.java index 642195fc2..c63b77e65 100644 --- a/topic/src/test/java/tech/ydb/topic/impl/TopicStreamTest.java +++ b/topic/src/test/java/tech/ydb/topic/impl/TopicStreamTest.java @@ -26,7 +26,7 @@ private interface MockedStream extends GrpcReadWriteStream { TestStream(MockedStream mock) { - super(logger, "test", mock); + super(logger, "test", mock, msg("init")); } @Override @@ -65,7 +65,7 @@ public void baseTest() { List received = new ArrayList<>(); TestStream stream = new TestStream(mock); - CompletableFuture result = stream.start(msg("init"), received::add); + CompletableFuture result = stream.start(received::add); Mockito.verify(mock).start(observer.capture()); stream.send(msg("s1")); @@ -104,7 +104,7 @@ public void startStreamAndImmediatelyFinishTest() { MockedStream mock = buildMockedStream("token", status); TestStream stream = new TestStream(mock); - stream.start(msg("init-req"), msg -> {}); + stream.start(msg -> {}); Mockito.verify(mock).start(Mockito.any()); stream.send(msg("s1")); @@ -123,7 +123,7 @@ public void nonSuccessMessageStopsStreamTest() { List received = new ArrayList<>(); TestStream stream = new TestStream(mock); - CompletableFuture result = stream.start(msg("init"), received::add); + CompletableFuture result = stream.start(received::add); Mockito.verify(mock).start(observer.capture()); @@ -153,7 +153,7 @@ public void tokenUpdatesTest() { Mockito.when(mock.start(Mockito.any())).thenReturn(streamFuture); TestStream stream = new TestStream(mock); - stream.start(msg("init"), msg -> {}); + stream.start(msg -> {}); Mockito.verify(mock).start(Mockito.any()); diff --git a/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamDirectFactoryTest.java b/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamDirectFactoryTest.java new file mode 100644 index 000000000..9dbc2ed90 --- /dev/null +++ b/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamDirectFactoryTest.java @@ -0,0 +1,443 @@ +package tech.ydb.topic.write.impl; + +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; + +import org.junit.Assert; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import tech.ydb.core.Issue; +import tech.ydb.core.Result; +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; +import tech.ydb.core.grpc.GrpcReadStream; +import tech.ydb.core.grpc.GrpcReadWriteStream; +import tech.ydb.core.grpc.GrpcRequestSettings; +import tech.ydb.proto.StatusCodesProtos; +import tech.ydb.proto.topic.YdbTopic; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; +import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.InitResponse; +import tech.ydb.topic.TopicRpc; +import tech.ydb.topic.settings.WriterSettings; + +/** + * + * @author Aleksandr Gorshenin {@literal } + */ +public class WriteStreamDirectFactoryTest { + private static class MockedStream { + @SuppressWarnings("unchecked") + private final GrpcReadWriteStream grpc = Mockito.mock(GrpcReadWriteStream.class); + @SuppressWarnings("unchecked") + private final ArgumentCaptor> observer = ArgumentCaptor + .forClass(GrpcReadStream.Observer.class); + + private final CompletableFuture result = new CompletableFuture<>(); + private final ArgumentCaptor msg = ArgumentCaptor.forClass(FromClient.class); + + public MockedStream() { + Mockito.when(grpc.authToken()).thenReturn(""); + Mockito.when(grpc.start(observer.capture())).thenReturn(result); + } + + public FromClient verifyNextMsg() { + Mockito.verify(grpc).sendNext(msg.capture()); + return msg.getValue(); + } + + public void responseWith(FromServer response) { + Mockito.doAnswer((iom) -> { + observer.getValue().onNext(response); + return null; + }).when(grpc).sendNext(Mockito.any()); + } + + public void responseWith(Status status) { + Mockito.doAnswer((iom) -> { + result.complete(status); + return null; + }).when(grpc).sendNext(Mockito.any()); + } + + public void responseWith(Exception ex) { + Mockito.doAnswer((iom) -> { + result.completeExceptionally(ex); + return null; + }).when(grpc).sendNext(Mockito.any()); + } + + public void closeImmediately(Status status) { + result.complete(status); + } + + public void fail(FromServer response) { + Mockito.doAnswer((iom) -> { + observer.getValue().onNext(response); + return null; + }).when(grpc).sendNext(Mockito.any()); + } + } + + private static YdbTopic.DescribeTopicResult.PartitionInfo partition(long partitionId, int nodeId, long generation) { + return YdbTopic.DescribeTopicResult.PartitionInfo.newBuilder() + .setPartitionId(partitionId) + .setPartitionLocation(YdbTopic.PartitionLocation.newBuilder() + .setNodeId(nodeId) + .setGeneration(generation) + .build()) + .build(); + } + + private static void mockDescribeResult(TopicRpc rpc, YdbTopic.DescribeTopicResult.PartitionInfo... partitions) { + Mockito.when(rpc.describeTopic(Mockito.any(), Mockito.any())) + .thenReturn(CompletableFuture.completedFuture(Result.success( + YdbTopic.DescribeTopicResult.newBuilder().addAllPartitions(Arrays.asList(partitions)).build()) + )); + } + + private static void mockDescribeResult(TopicRpc rpc, Status status) { + Mockito.when(rpc.describeTopic(Mockito.any(), Mockito.any())) + .thenReturn(CompletableFuture.completedFuture(Result.fail(status))); + } + + @Test + public void invalidDirectWriteTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath("/local/topic") + .setDirectWrite(true) // requires producerId or partitionId + .build(); + + Exception ex = Assert.assertThrows( + IllegalArgumentException.class, + () -> new WriteStreamDirectFactory(rpc, settings) + ); + Assert.assertEquals("Direct writing requires PartitionId or ProducerId in WriterSettings", ex.getMessage()); + } + + @Test + public void directWriteByPartitionIdTest() { + MockedStream mocked = new MockedStream(); + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + mockDescribeResult(rpc, partition(1L, 10, 3L), partition(2L, 42, 1L), partition(3L, 23, 2L)); + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(mocked.grpc); + + WriterSettings settings = WriterSettings.newBuilder() + .setTopicPath("/local/topic") + .setPartitionId(2L) + .setDirectWrite(true) + .build(); + + // just verify it doesn't throw and returns a factory for the correct topic + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, settings); + Assert.assertEquals("/local/topic", factory.getTopicPath()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream); + + ArgumentCaptor options = ArgumentCaptor.forClass(GrpcRequestSettings.class); + Mockito.verify(rpc).writeSession(options.capture()); + Assert.assertTrue(options.getValue().isDirectMode()); + Assert.assertEquals(42, options.getValue().getPreferredNodeID().intValue()); + + stream.start(null); + + FromClient msg = mocked.verifyNextMsg(); + Assert.assertTrue(msg.hasInitRequest()); + Assert.assertEquals("/local/topic", msg.getInitRequest().getPath()); + Assert.assertFalse(msg.getInitRequest().hasPartitionId()); + Assert.assertEquals("", msg.getInitRequest().getProducerId()); + Assert.assertEquals(2L, msg.getInitRequest().getPartitionWithGeneration().getPartitionId()); + Assert.assertEquals(1L, msg.getInitRequest().getPartitionWithGeneration().getGeneration()); + } + + @Test + public void directWriteByPartitionIdTestDescribeFailTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + mockDescribeResult(rpc, Status.of(StatusCode.UNAVAILABLE)); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setPartitionId(3L) + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + + Mockito.verify(rpc, Mockito.never()).writeSession(Mockito.any(GrpcRequestSettings.class)); + + Assert.assertTrue(stream instanceof WriteStream.Fail); + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Assert.assertEquals(Status.of(StatusCode.UNAVAILABLE), res.join()); + + stream.close(); // no effect + } + + @Test + public void directWriteByPartitionIdTestPartitionNotFoundTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + // result has partition 5, but we're looking for partition 3 + mockDescribeResult(rpc, partition(4L, 99, 1L), partition(5L, 100, 2L)); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setPartitionId(3L) + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + + Mockito.verify(rpc, Mockito.never()).writeSession(Mockito.any(GrpcRequestSettings.class)); + + Assert.assertTrue(stream instanceof WriteStream.Fail); + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Status expected = Status.of(StatusCode.BAD_REQUEST, Issue.of("Cannot find partition 3", Issue.Severity.ERROR)); + Assert.assertEquals(expected, res.join()); + + stream.close(); // no effect + } + + @Test + public void directWriteByPartitionIdTestPartitionHasNoLocationTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + mockDescribeResult(rpc, YdbTopic.DescribeTopicResult.PartitionInfo.newBuilder() + .setPartitionId(3L) + .build()); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setPartitionId(3L) + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + + Mockito.verify(rpc, Mockito.never()).writeSession(Mockito.any(GrpcRequestSettings.class)); + + Assert.assertTrue(stream instanceof WriteStream.Fail); + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Status expected = Status.of(StatusCode.BAD_REQUEST, Issue.of("Partition 3 has no location", Issue.Severity.ERROR)); + Assert.assertEquals(expected, res.join()); + + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + MockedStream probe = new MockedStream(); + MockedStream actual = new MockedStream(); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))) + .thenReturn(probe.grpc).thenReturn(actual.grpc); + + mockDescribeResult(rpc, partition(7L, 55, 3L)); + + probe.responseWith(FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setInitResponse(InitResponse.newBuilder() + .setLastSeqNo(0) + .setPartitionId(7L) + .setSessionId("session") + .build()) + .build()); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setMessageGroupId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream); + + ArgumentCaptor options = ArgumentCaptor.forClass(GrpcRequestSettings.class); + Mockito.verify(rpc, Mockito.times(2)).writeSession(options.capture()); + Assert.assertTrue(options.getValue().isDirectMode()); + Assert.assertEquals(55, options.getValue().getPreferredNodeID().intValue()); + + stream.start(null); + + FromClient msg = actual.verifyNextMsg(); + Assert.assertTrue(msg.hasInitRequest()); + Assert.assertEquals("/test/topic", msg.getInitRequest().getPath()); + Assert.assertFalse(msg.getInitRequest().hasPartitionId()); + Assert.assertEquals("producer-1", msg.getInitRequest().getProducerId()); + Assert.assertEquals("", msg.getInitRequest().getMessageGroupId()); // never used for direct-write + Assert.assertEquals(7L, msg.getInitRequest().getPartitionWithGeneration().getPartitionId()); + Assert.assertEquals(3L, msg.getInitRequest().getPartitionWithGeneration().getGeneration()); + } + + @Test + public void directWriteByProducerIdProbeFailTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + MockedStream probe = new MockedStream(); + probe.closeImmediately(Status.of(StatusCode.UNAUTHORIZED)); + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probe.grpc); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Assert.assertEquals(Status.of(StatusCode.UNAUTHORIZED), res.join()); + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdProbeFailOnSendTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + MockedStream probe = new MockedStream(); + probe.responseWith(Status.of(StatusCode.PRECONDITION_FAILED)); + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probe.grpc); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Assert.assertEquals(Status.of(StatusCode.PRECONDITION_FAILED), res.join()); + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdProbeExceptionOnSendTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + MockedStream probe = new MockedStream(); + probe.responseWith(new RuntimeException("something went wrong")); + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probe.grpc); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Status status = res.join(); + Assert.assertEquals(StatusCode.CLIENT_INTERNAL_ERROR, status.getCode()); + Assert.assertNotNull(status.getCause()); + Assert.assertEquals("something went wrong", status.getCause().getMessage()); + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdProbeWrongResponseTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + MockedStream probe = new MockedStream(); + probe.responseWith(FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.INTERNAL_ERROR) + .build()); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probe.grpc); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Assert.assertEquals(Status.of(StatusCode.INTERNAL_ERROR), res.join()); + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdProbeUnexpectedResponseTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + MockedStream probe = new MockedStream(); + probe.responseWith(FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setUpdateTokenResponse(YdbTopic.UpdateTokenResponse.newBuilder().build()) + .build()); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probe.grpc); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); + + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Issue issue = Issue.of("Unexpected message from stream with producer producer-1", Issue.Severity.ERROR); + Assert.assertEquals(Status.of(StatusCode.BAD_REQUEST, issue), res.join()); + stream.close(); // no effect + } + + @Test + public void directWriteByProducerIdPartitionNotFoundTest() { + TopicRpc rpc = Mockito.mock(TopicRpc.class); + + MockedStream probe = new MockedStream(); + probe.responseWith(FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setInitResponse(InitResponse.newBuilder() + .setLastSeqNo(0) + .setPartitionId(5L) + .setSessionId("session") + .build()) + .build()); + + mockDescribeResult(rpc, partition(1L, 55, 8L), partition(2L, 55, 7L)); + + Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probe.grpc); + + WriteStreamFactory factory = new WriteStreamDirectFactory(rpc, WriterSettings.newBuilder() + .setTopicPath("/test/topic") + .setProducerId("producer-1") + .setDirectWrite(true) + .build()); + + WriteSession.Stream stream = factory.createNewStream("s1"); + Assert.assertTrue(stream instanceof WriteStream.Fail); + CompletableFuture res = stream.start(null); + Assert.assertTrue(res.isDone()); + Status expected = Status.of(StatusCode.BAD_REQUEST, Issue.of("Cannot find partition 5", Issue.Severity.ERROR)); + Assert.assertEquals(expected, res.join()); + } +} diff --git a/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamFactoryTest.java b/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamFactoryTest.java index ce48d8183..267fdae10 100644 --- a/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamFactoryTest.java +++ b/topic/src/test/java/tech/ydb/topic/write/impl/WriteStreamFactoryTest.java @@ -1,22 +1,12 @@ package tech.ydb.topic.write.impl; -import java.util.Arrays; -import java.util.concurrent.CompletableFuture; import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; -import tech.ydb.core.Issue; -import tech.ydb.core.Result; -import tech.ydb.core.Status; -import tech.ydb.core.StatusCode; -import tech.ydb.core.grpc.GrpcReadStream; import tech.ydb.core.grpc.GrpcReadWriteStream; -import tech.ydb.core.grpc.GrpcRequestSettings; -import tech.ydb.proto.StatusCodesProtos; import tech.ydb.proto.topic.YdbTopic; -import tech.ydb.proto.topic.YdbTopic.DescribeTopicResult; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromClient; import tech.ydb.proto.topic.YdbTopic.StreamWriteMessage.FromServer; import tech.ydb.topic.TopicRpc; @@ -26,57 +16,11 @@ * @author Aleksandr Gorshenin */ public class WriteStreamFactoryTest { - - @SuppressWarnings("unchecked") - private static GrpcReadWriteStream mockGrpcStream() { - GrpcReadWriteStream grpc = Mockito.mock(GrpcReadWriteStream.class); - Mockito.when(grpc.authToken()).thenReturn(""); - return grpc; - } - - private static void mockStreamError(GrpcReadWriteStream mock, Status error) { - Mockito.when(mock.start(Mockito.any())).thenReturn(CompletableFuture.completedFuture(error)); - } - - private static void mockStreamResponse(GrpcReadWriteStream mock, FromServer response) { - CompletableFuture result = new CompletableFuture<>(); - - Mockito.when(mock.start(Mockito.any())).thenAnswer(iom -> { - GrpcReadStream.Observer obs = iom.getArgument(0); - obs.onNext(response); - return result; - }).thenReturn(result); - - Mockito.doAnswer((iom) -> { - result.complete(Status.SUCCESS); - return null; - }).when(mock).close(); - } - - private static DescribeTopicResult.PartitionInfo partition(long partitionId, int nodeId) { - return DescribeTopicResult.PartitionInfo.newBuilder() - .setPartitionId(partitionId) - .setPartitionLocation(YdbTopic.PartitionLocation.newBuilder() - .setNodeId(nodeId) - .build()) - .build(); - } - - private static void mockDescribeResult(TopicRpc rpc, DescribeTopicResult.PartitionInfo... partitions) { - Mockito.when(rpc.describeTopic(Mockito.any(), Mockito.any())) - .thenReturn(CompletableFuture.completedFuture(Result.success( - DescribeTopicResult.newBuilder().addAllPartitions(Arrays.asList(partitions)).build()) - )); - } - - private static void mockDescribeResult(TopicRpc rpc, Status status) { - Mockito.when(rpc.describeTopic(Mockito.any(), Mockito.any())) - .thenReturn(CompletableFuture.completedFuture(Result.fail(status))); - } - @Test public void regularWriteTest() { - GrpcReadWriteStream grpc = mockGrpcStream(); + @SuppressWarnings("unchecked") + GrpcReadWriteStream grpc = Mockito.mock(GrpcReadWriteStream.class); + TopicRpc rpc = Mockito.mock(TopicRpc.class); Mockito.when(rpc.writeSession(Mockito.eq("s1"))).thenReturn(grpc); @@ -84,7 +28,7 @@ public void regularWriteTest() { .setTopicPath("/local/topic") .build(); - WriteStreamFactory factory = WriteStreamFactory.of(rpc, settings); + WriteStreamFactory factory = new WriteStreamFactory(rpc, settings); Assert.assertEquals("/local/topic", factory.getTopicPath()); WriteSession.Stream stream = factory.createNewStream("s1"); @@ -95,12 +39,11 @@ public void regularWriteTest() { @Test public void writeWithoutDeduplicationTest() { TopicRpc rpc = Mockito.mock(TopicRpc.class); - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + WriteStreamFactory factory = new WriteStreamFactory(rpc, WriterSettings.newBuilder() .setTopicPath("/test/topic") .build()); - YdbTopic.StreamWriteMessage.InitRequest req = factory.initRequest() - .getInitRequest(); + YdbTopic.StreamWriteMessage.InitRequest req = factory.buildInitRequest(); Assert.assertEquals("/test/topic", req.getPath()); Assert.assertEquals("", req.getProducerId()); Assert.assertFalse(req.hasMessageGroupId()); @@ -110,13 +53,12 @@ public void writeWithoutDeduplicationTest() { @Test public void writeWithProducerIdTest() { TopicRpc rpc = Mockito.mock(TopicRpc.class); - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + WriteStreamFactory factory = new WriteStreamFactory(rpc, WriterSettings.newBuilder() .setTopicPath("/test/topic") .setProducerId("producer") .build()); - YdbTopic.StreamWriteMessage.InitRequest req = factory.initRequest() - .getInitRequest(); + YdbTopic.StreamWriteMessage.InitRequest req = factory.buildInitRequest(); Assert.assertEquals("/test/topic", req.getPath()); Assert.assertEquals("producer", req.getProducerId()); Assert.assertFalse(req.hasMessageGroupId()); @@ -126,14 +68,13 @@ public void writeWithProducerIdTest() { @Test public void writeWithProducerIdAndMessageGroupIdTest() { TopicRpc rpc = Mockito.mock(TopicRpc.class); - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + WriteStreamFactory factory = new WriteStreamFactory(rpc, WriterSettings.newBuilder() .setTopicPath("/test/topic") .setProducerId("producer") .setMessageGroupId("producer") .build()); - YdbTopic.StreamWriteMessage.InitRequest req = factory.initRequest() - .getInitRequest(); + YdbTopic.StreamWriteMessage.InitRequest req = factory.buildInitRequest(); Assert.assertEquals("/test/topic", req.getPath()); Assert.assertEquals("producer", req.getProducerId()); Assert.assertEquals("producer", req.getMessageGroupId()); @@ -143,12 +84,12 @@ public void writeWithProducerIdAndMessageGroupIdTest() { @Test public void writeWithPartitionIdTest() { TopicRpc rpc = Mockito.mock(TopicRpc.class); - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() + WriteStreamFactory factory = new WriteStreamFactory(rpc, WriterSettings.newBuilder() .setTopicPath("/test/topic") .setPartitionId(5L) .build()); - YdbTopic.StreamWriteMessage.InitRequest req = factory.initRequest().getInitRequest(); + YdbTopic.StreamWriteMessage.InitRequest req = factory.buildInitRequest(); Assert.assertEquals(5L, req.getPartitionId()); Assert.assertFalse(req.hasMessageGroupId()); } @@ -161,244 +102,7 @@ public void messageGroupAndPartitionErrorTest() { .setMessageGroupId("group-1") .setPartitionId(5L) .build(); - Exception ex = Assert.assertThrows(IllegalArgumentException.class, () -> WriteStreamFactory.of(rpc, settings)); + Exception ex = Assert.assertThrows(IllegalArgumentException.class, () -> new WriteStreamFactory(rpc, settings)); Assert.assertEquals("Both MessageGroupId and PartitionId are set in WriterSettings", ex.getMessage()); } - - @Test - public void invalidDirectWriteTest() { - TopicRpc rpc = Mockito.mock(TopicRpc.class); - WriterSettings settings = WriterSettings.newBuilder() - .setTopicPath("/local/topic") - .setDirectWrite(true) // requires producerId or partitionId - .build(); - - Exception ex = Assert.assertThrows(IllegalArgumentException.class, () -> WriteStreamFactory.of(rpc, settings)); - Assert.assertEquals("Direct writing requires PartitionId or ProducerId in WriterSettings", ex.getMessage()); - } - - @Test - public void directWriteByPartitionIdTest() { - GrpcReadWriteStream grpc = mockGrpcStream(); - TopicRpc rpc = Mockito.mock(TopicRpc.class); - - mockDescribeResult(rpc, partition(1L, 10), partition(2L, 42), partition(3L, 23)); - Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(grpc); - - WriterSettings settings = WriterSettings.newBuilder() - .setTopicPath("/local/topic") - .setPartitionId(2L) - .setDirectWrite(true) - .build(); - - // just verify it doesn't throw and returns a factory for the correct topic - WriteStreamFactory factory = WriteStreamFactory.of(rpc, settings); - Assert.assertEquals("/local/topic", factory.getTopicPath()); - - WriteSession.Stream stream = factory.createNewStream("s1"); - Assert.assertTrue(stream instanceof WriteStream); - Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); - } - - @Test - public void directWriteByPartitionIdTestDescribeFailTest() { - TopicRpc rpc = Mockito.mock(TopicRpc.class); - mockDescribeResult(rpc, Status.of(StatusCode.UNAVAILABLE)); - - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() - .setTopicPath("/test/topic") - .setPartitionId(3L) - .setDirectWrite(true) - .build()); - - WriteSession.Stream stream = factory.createNewStream("s1"); - - Mockito.verify(rpc, Mockito.never()).writeSession(Mockito.any(GrpcRequestSettings.class)); - - Assert.assertTrue(stream instanceof WriteStream.Fail); - CompletableFuture res = stream.start(null, null); - Assert.assertTrue(res.isDone()); - Assert.assertEquals(Status.of(StatusCode.UNAVAILABLE), res.join()); - - stream.close(); // no effect - } - - @Test - public void directWriteByPartitionIdTestPartitionNotFoundTest() { - TopicRpc rpc = Mockito.mock(TopicRpc.class); - // result has partition 5, but we're looking for partition 3 - mockDescribeResult(rpc, partition(4L, 99), partition(5L, 100)); - - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() - .setTopicPath("/test/topic") - .setPartitionId(3L) - .setDirectWrite(true) - .build()); - - WriteSession.Stream stream = factory.createNewStream("s1"); - - Mockito.verify(rpc, Mockito.never()).writeSession(Mockito.any(GrpcRequestSettings.class)); - - Assert.assertTrue(stream instanceof WriteStream.Fail); - CompletableFuture res = stream.start(null, null); - Assert.assertTrue(res.isDone()); - Status expected = Status.of(StatusCode.BAD_REQUEST, Issue.of("Cannot find partition 3", Issue.Severity.ERROR)); - Assert.assertEquals(expected, res.join()); - - stream.close(); // no effect - } - - @Test - public void directWriteByProducerIdTest() { - TopicRpc rpc = Mockito.mock(TopicRpc.class); - - GrpcReadWriteStream probeGrpc = mockGrpcStream(); - GrpcReadWriteStream actualGrpc = mockGrpcStream(); - - FromServer initResponse = FromServer.newBuilder() - .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) - .setInitResponse(YdbTopic.StreamWriteMessage.InitResponse.newBuilder() - .setLastSeqNo(0) - .setPartitionId(7L) - .setSessionId("session") - .build()) - .build(); - - mockStreamResponse(probeGrpc, initResponse); - mockDescribeResult(rpc, partition(7L, 55)); - - Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))) - .thenReturn(probeGrpc).thenReturn(actualGrpc); - - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() - .setTopicPath("/test/topic") - .setProducerId("producer-1") - .setDirectWrite(true) - .build()); - - WriteSession.Stream stream = factory.createNewStream("s1"); - Assert.assertTrue(stream instanceof WriteStream); - Mockito.verify(rpc, Mockito.times(2)).writeSession(Mockito.any(GrpcRequestSettings.class)); - } - - @Test - public void directWriteByProducerIdProbeFailTest() { - TopicRpc rpc = Mockito.mock(TopicRpc.class); - - GrpcReadWriteStream probeGrpc = mockGrpcStream(); - - mockStreamError(probeGrpc, Status.of(StatusCode.UNAUTHORIZED)); - - Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probeGrpc); - - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() - .setTopicPath("/test/topic") - .setProducerId("producer-1") - .setDirectWrite(true) - .build()); - - WriteSession.Stream stream = factory.createNewStream("s1"); - Assert.assertTrue(stream instanceof WriteStream.Fail); - Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); - - CompletableFuture res = stream.start(null, null); - Assert.assertTrue(res.isDone()); - Assert.assertEquals(Status.of(StatusCode.UNAUTHORIZED), res.join()); - stream.close(); // no effect - } - - @Test - public void directWriteByProducerIdProbeWrongResponseTest() { - TopicRpc rpc = Mockito.mock(TopicRpc.class); - - GrpcReadWriteStream probeGrpc = mockGrpcStream(); - - FromServer initResponse = FromServer.newBuilder() - .setStatus(StatusCodesProtos.StatusIds.StatusCode.INTERNAL_ERROR) - .build(); - mockStreamResponse(probeGrpc, initResponse); - - Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probeGrpc); - - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() - .setTopicPath("/test/topic") - .setProducerId("producer-1") - .setDirectWrite(true) - .build()); - - WriteSession.Stream stream = factory.createNewStream("s1"); - Assert.assertTrue(stream instanceof WriteStream.Fail); - Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); - - CompletableFuture res = stream.start(null, null); - Assert.assertTrue(res.isDone()); - Assert.assertEquals(Status.of(StatusCode.INTERNAL_ERROR), res.join()); - stream.close(); // no effect - } - - @Test - public void directWriteByProducerIdProbeUnexpectedResponseTest() { - TopicRpc rpc = Mockito.mock(TopicRpc.class); - - GrpcReadWriteStream probeGrpc = mockGrpcStream(); - - FromServer initResponse = FromServer.newBuilder() - .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) - .setUpdateTokenResponse(YdbTopic.UpdateTokenResponse.newBuilder().build()) - .build(); - mockStreamResponse(probeGrpc, initResponse); - - Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probeGrpc); - - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() - .setTopicPath("/test/topic") - .setProducerId("producer-1") - .setDirectWrite(true) - .build()); - - WriteSession.Stream stream = factory.createNewStream("s1"); - Assert.assertTrue(stream instanceof WriteStream.Fail); - Mockito.verify(rpc).writeSession(Mockito.any(GrpcRequestSettings.class)); - - CompletableFuture res = stream.start(null, null); - Assert.assertTrue(res.isDone()); - Issue issue = Issue.of("Unexpected message from stream with producer producer-1", Issue.Severity.ERROR); - Assert.assertEquals(Status.of(StatusCode.BAD_REQUEST, issue), res.join()); - stream.close(); // no effect - } - - @Test - public void directWriteByProducerIdPartitionNotFoundTest() { - TopicRpc rpc = Mockito.mock(TopicRpc.class); - - GrpcReadWriteStream probeGrpc = mockGrpcStream(); -// GrpcReadWriteStream actualGrpc = mockGrpcStream(); - - FromServer initResponse = FromServer.newBuilder() - .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) - .setInitResponse(YdbTopic.StreamWriteMessage.InitResponse.newBuilder() - .setLastSeqNo(0) - .setPartitionId(5L) - .setSessionId("session") - .build()) - .build(); - - mockStreamResponse(probeGrpc, initResponse); - mockDescribeResult(rpc, partition(1L, 55), partition(2L, 55)); - - Mockito.when(rpc.writeSession(Mockito.any(GrpcRequestSettings.class))).thenReturn(probeGrpc); - - WriteStreamFactory factory = WriteStreamFactory.of(rpc, WriterSettings.newBuilder() - .setTopicPath("/test/topic") - .setProducerId("producer-1") - .setDirectWrite(true) - .build()); - - WriteSession.Stream stream = factory.createNewStream("s1"); - Assert.assertTrue(stream instanceof WriteStream.Fail); - CompletableFuture res = stream.start(null, null); - Assert.assertTrue(res.isDone()); - Status expected = Status.of(StatusCode.BAD_REQUEST, Issue.of("Cannot find partition 5", Issue.Severity.ERROR)); - Assert.assertEquals(expected, res.join()); - } }