/*
 * Copyright 2016 gRPC authors.
 * Copyright 2021 Bloomberg Finance LP
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef INCLUDED_BUILDBOXCOMMON_GRPCTESTSERVER
#define INCLUDED_BUILDBOXCOMMON_GRPCTESTSERVER

#include <atomic>
#include <chrono>
#include <functional>
#include <memory>
#include <optional>
#include <stdexcept>
#include <stop_token>
#include <string>
#include <thread>
#include <typeindex>
#include <unordered_map>
#include <vector>

#include <google/protobuf/message.h>
#include <google/protobuf/util/message_differencer.h>
#include <grpcpp/generic/async_generic_service.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/support/byte_buffer.h>
#include <gtest/gtest.h>

namespace buildboxcommon {

/**
 * An in-process gRPC server for testing client code. It serves a similar
 * purpose to mock client stubs, but works with both sync and async client
 * code.
 *
 * Construct one `GrpcTestServerContext` instance for each gRPC method call
 * that is to be mocked. As the `GrpcTestServerContext` methods are blocking,
 * a separate thread should be used.
 */
class GrpcTestServer {
  public:
    /**
     * Build and start the server: registers the generic async service, adds
     * an insecure listening port on "localhost:0" (the selected port is
     * stored in `d_port`), adds one completion queue and creates the
     * in-process channel.
     *
     * Throws `std::runtime_error` if the server could not be started.
     */
    GrpcTestServer()
    {
        grpc::ServerBuilder builder;
        builder.RegisterAsyncGenericService(&d_genericService);
        builder.AddListeningPort("localhost:0",
                                 grpc::InsecureServerCredentials(), &d_port);
        d_cq = builder.AddCompletionQueue();
        d_server = builder.BuildAndStart();
        if (!d_server) {
            // Fail loudly instead of dereferencing a null server below.
            throw std::runtime_error("Failed to start gRPC test server");
        }
        d_channel = d_server->InProcessChannel(grpc::ChannelArguments());
    }

    // Contexts keep raw pointers to the server and its completion queue, so
    // the server must never be copied.
    GrpcTestServer(const GrpcTestServer &) = delete;
    GrpcTestServer &operator=(const GrpcTestServer &) = delete;

    /**
     * Shut down the server first, then the completion queue, and finally
     * drain every event that is still queued.
     */
    virtual ~GrpcTestServer()
    {
        d_server->Shutdown();
        d_cq->Shutdown();
        // Drain queue
        void *ignored_tag = nullptr;
        bool ignored_ok = false;
        while (d_cq->Next(&ignored_tag, &ignored_ok)) {
        }
    }

    /**
     * Return the in-process gRPC channel, which test clients can use to
     * call gRPC methods on the test server.
     */
    std::shared_ptr<grpc::Channel> channel() const { return d_channel; }

    /**
     * Return the URL, which out-of-process test clients can use to
     * call gRPC methods on the test server.
     */
    std::string url() const
    {
        return "http://localhost:" + std::to_string(d_port);
    }

    friend class GrpcTestServerContext;
    friend class GrpcTestServerGenericContext;
    friend class BaseGrpcTestServerContext;

  protected:
    std::unique_ptr<grpc::Server> d_server;
    std::unique_ptr<grpc::ServerCompletionQueue> d_cq;
    std::shared_ptr<grpc::Channel> d_channel;
    grpc::AsyncGenericService d_genericService;
    // Port selected for the "localhost:0" listener; 0 until the server has
    // been built.
    int d_port{0};
};

/**
 * Common functionality of the test server contexts: owns the generic
 * reader/writer stream of the current call and provides blocking helpers
 * that start one async operation and then wait for its completion event on
 * the server's completion queue. Failures are reported through gtest
 * expectations.
 */
class BaseGrpcTestServerContext {
  protected:
    /**
     * Bind this context to `server` and cache a pointer to its completion
     * queue. `server` is not owned and is stored as a raw pointer.
     */
    BaseGrpcTestServerContext(GrpcTestServer *server) : d_server(server)
    {
        d_cq = d_server->d_cq.get();
    }

    /**
     * Replace `d_stream` with a new reader/writer bound to `ctx`.
     */
    void initializeStream(grpc::GenericServerContext *ctx)
    {
        d_stream = std::make_unique<grpc::GenericServerAsyncReaderWriter>(ctx);
    }

  public:
    /**
     * Block until the next completion queue event and expect that it
     * completed successfully.
     */
    void next()
    {
        void *tag = nullptr;
        bool ok = false;
        EXPECT_TRUE(d_cq->Next(&tag, &ok) && ok);
    }

    /**
     * Wait for and read a message from the client and verify that it matches
     * the expected message. `repeatedFieldComparison` selects how repeated
     * fields are compared (as ordered lists by default). A mismatch is
     * reported as a fatal gtest failure that includes the differences.
     */
    template <typename R>
    void
    read(const R &expectedRequest,
         google::protobuf::util::MessageDifferencer::RepeatedFieldComparison
             repeatedFieldComparison = google::protobuf::util::
                 MessageDifferencer::RepeatedFieldComparison::AS_LIST)
    {
        grpc::ByteBuffer inBuffer;
        R inMsg;
        d_stream->Read(&inBuffer, nullptr);
        next();
        EXPECT_TRUE(parseFromByteBuffer(&inBuffer, &inMsg));

        std::string diff;
        google::protobuf::util::MessageDifferencer differencer;
        differencer.set_repeated_field_comparison(repeatedFieldComparison);
        differencer.ReportDifferencesToString(&diff);
        if (!differencer.Compare(expectedRequest, inMsg)) {
            FAIL() << "Unexpected differences in received request:\n" << diff;
        }
    }

    /**
     * Send a message back to the client and keep the stream open.
     */
    template <typename W> void write(const W &response)
    {
        auto outBuffer = serializeToByteBuffer(response);
        d_stream->Write(*outBuffer, nullptr);
        next();
    }

    /**
     * Send a message back to the client and finish the stream with an
     * `OK` status.
     */
    template <typename W> void writeAndFinish(const W &response)
    {
        auto outBuffer = serializeToByteBuffer(response);
        d_stream->WriteAndFinish(*outBuffer, grpc::WriteOptions(),
                                 grpc::Status::OK, nullptr);
        next();
    }

    /**
     * Finish the stream with the specified status.
     */
    void finish(const grpc::Status &status)
    {
        d_stream->Finish(status, nullptr);
        next();
    }

    /**
     * Expect client cancelling: the `Finish` operation must produce a
     * completion event, and that event must not be successful.
     */
    void expectCancel()
    {
        void *tag = nullptr;
        bool ok = false;
        d_stream->Finish(grpc::Status::OK, nullptr);
        EXPECT_TRUE(d_cq->Next(&tag, &ok));
        EXPECT_FALSE(ok);
    }

    /**
     * Concatenate all slices of `buffer` and parse the bytes into `message`.
     * Returns false if the buffer cannot be dumped or the bytes do not parse.
     */
    static bool parseFromByteBuffer(grpc::ByteBuffer *buffer,
                                    google::protobuf::Message *message)
    {
        std::vector<grpc::Slice> slices;
        // Without this check a failed dump would leave an empty string,
        // which may parse successfully as an empty message.
        if (!buffer->Dump(&slices).ok()) {
            return false;
        }

        std::string s;
        for (const auto &slice : slices) {
            s.append(reinterpret_cast<const char *>(slice.begin()),
                     slice.size());
        }
        return message->ParseFromString(s);
    }

    /**
     * Serialize `message` into a newly allocated single-slice byte buffer.
     */
    static std::unique_ptr<grpc::ByteBuffer>
    serializeToByteBuffer(const google::protobuf::Message &message)
    {
        std::string buf;
        message.SerializeToString(&buf);
        grpc::Slice slice(buf);
        return std::make_unique<grpc::ByteBuffer>(&slice, 1);
    }

    // Server this context belongs to (not owned).
    GrpcTestServer *d_server;
    // Completion queue of `d_server` (not owned).
    grpc::ServerCompletionQueue *d_cq;
    // Reader/writer for the call currently being handled.
    std::unique_ptr<grpc::GenericServerAsyncReaderWriter> d_stream;
};

/*
 * A gRPC server context bound to one expected method call.
 * Construction blocks until a call arrives; afterwards the caller drives
 * the stream manually through the inherited read/write/finish helpers.
 */
class GrpcTestServerContext : public BaseGrpcTestServerContext {
  public:
    // Exposed so that tests can inspect or return metadata and other
    // per-request information through the underlying context.
    grpc::GenericServerContext d_ctx;

    /**
     * Request the next incoming call from the server's generic service, wait
     * for its completion event and expect that the called method equals
     * `method`. The method name must be fully qualified with the package
     * name and the service name.
     */
    GrpcTestServerContext(GrpcTestServer *server, const std::string &method)
        : BaseGrpcTestServerContext(server)
    {
        initializeStream(&d_ctx);

        // The same completion queue is used for the call itself and for the
        // notification of the new call.
        grpc::AsyncGenericService &genericService = d_server->d_genericService;
        genericService.RequestCall(&d_ctx, d_stream.get(), d_cq, d_cq,
                                   nullptr);

        next();
        EXPECT_EQ(d_ctx.method(), method);
    }
};

/*
 * A generalized version of the GrpcTestServerContext.
 * Instead of waiting to Read/Write a single method with a fixed request
 * and response, multiple methods can be registered with different request/
 * response pairs. Then the server can poll for incoming requests and respond
 * accordingly.
 */
class GrpcTestServerGenericContext : public BaseGrpcTestServerContext {
  public:
    struct RequestResponsePair {
        std::unique_ptr<google::protobuf::Message> request;
        std::unique_ptr<google::protobuf::Message> response;

        template <typename Req, typename Resp>
        RequestResponsePair(const Req &req, const Resp &resp)
            : request{std::make_unique<Req>(req)},
              response{std::make_unique<Resp>(resp)}
        {
        }
    };

    // Store registered mock methods with their request/response pairs
    std::unordered_map<std::string, std::vector<RequestResponsePair>>
        d_mockMethodRegistry;

    // Store handler function for each method
    using MethodHandler = std::function<void(const std::string &)>;
    std::unordered_map<std::string, MethodHandler> d_methodHandlers;

    // Store type information for each registered method
    struct MethodTypeInfo {
        std::type_index requestType;
        std::type_index responseType;

        MethodTypeInfo(std::type_index req, std::type_index resp)
            : requestType(req), responseType(resp)
        {
        }
    };

    std::unordered_map<std::string, MethodTypeInfo> d_methodTypes;

    // Default timeout for polling completion queue
    static constexpr std::chrono::milliseconds DEFAULT_POLL_TIMEOUT{100};

    /**
     * Create a server context and wait for the client to call a method.
     * The server context allows user to register expected requests and
     * will output the corresponding responses.
     */
    GrpcTestServerGenericContext(GrpcTestServer *server)
        : BaseGrpcTestServerContext(server)
    {
    }

    /*
     * Poll for a single incoming request and handle it. Polling is
     * interruptible via timeouts and periodically checks whether it should
     * continue accepting requests.
     */
    void poll_one()
    {
        if (!d_acceptingRequests) {
            return;
        }

        // Requires a fresh context for each incoming request
        grpc::GenericServerContext ctx;
        initializeStream(&ctx);

        d_server->d_genericService.RequestCall(&ctx, d_stream.get(), d_cq,
                                               d_cq, nullptr);

        void *tag;
        bool ok;

        while (d_acceptingRequests) {
            auto deadline =
                std::chrono::system_clock::now() + DEFAULT_POLL_TIMEOUT;
            auto status = d_cq->AsyncNext(&tag, &ok, deadline);

            if (status == grpc::CompletionQueue::TIMEOUT) {
                continue;
            }
            else if (status == grpc::CompletionQueue::GOT_EVENT) {
                break;
            }
            else {
                // The completion queue is shutting down
                return;
            }
        }

        // Completion queue might've gotten a request that is not ok
        // This can happen during shutdown - in which case we just ignore it
        if (d_acceptingRequests && ok) {
            handleMethod(ctx.method());
        }
    }

    void stopPolling() { d_acceptingRequests = false; }

    /*
     * Register a method with its request and response types.
     * Must be called before adding mock responses for this method.
     */
    template <typename Req, typename Resp>
    void registerMethod(const std::string &method)
    {
        if (isMethodRegistered(method)) {
            if (!isValidMethodType<Req, Resp>(method)) {
                throw std::runtime_error(
                    "Method already registered with different "
                    "request/response types: " +
                    method);
            }
            return;
        }

        d_methodTypes.emplace(method,
                              MethodTypeInfo(typeid(Req), typeid(Resp)));

        d_methodHandlers[method] = [this](const std::string &methodName) {
            this->handleMethodTyped<Req>(methodName);
        };
    }

    /*
     * Add a mock request/response pair for a registered method.
     * The method must be registered first via registerMethod().
     */
    template <typename Req, typename Resp>
    void addMockMethodCall(const std::string &method,
                           const Req &expectedRequest, const Resp &response)
    {
        if (!isMethodRegistered(method)) {
            throw std::runtime_error("Method not registered: " + method);
        }

        if (!isValidMethodType<Req, Resp>(method)) {
            throw std::runtime_error("Request/response types don't match "
                                     "registered method schema: " +
                                     method);
        }

        if (auto existing =
                tryFindMethodRequestResponsePair(method, expectedRequest);
            existing.has_value()) {
            auto existingResponse = existing.value().get().response.get();
            if (compareMessages(*existingResponse, response)) {
                return;
            }

            throw std::runtime_error(
                "Method with identical request already registered: " + method);
        }

        d_mockMethodRegistry[method].emplace_back(
            RequestResponsePair(expectedRequest, response));
    }

  private:
    // Control flag to accept or reject incoming requests
    std::atomic_bool d_acceptingRequests{true};

    template <typename Req, typename Resp>
    [[nodiscard]] bool isValidMethodType(const std::string &method) const
    {
        if (auto it = d_methodTypes.find(method); it != d_methodTypes.end()) {
            const auto &[_, typeInfo] = *it;
            if (typeInfo.requestType != typeid(Req) ||
                typeInfo.responseType != typeid(Resp)) {
                return false;
            }
        }

        return true;
    }

    [[nodiscard]]
    inline bool isMethodRegistered(const std::string &method) const
    {
        return d_methodTypes.contains(method);
    }

    void handleMethod(const std::string &method)
    {
        if (!d_mockMethodRegistry.contains(method)) {
            finish(grpc::Status(grpc::StatusCode::UNIMPLEMENTED,
                                "Unimplemented"));
            return;
        }

        d_methodHandlers[method](method);
    }

    template <typename Req> void handleMethodTyped(const std::string &method)
    {
        grpc::ByteBuffer inBuffer;
        d_stream->Read(&inBuffer, nullptr);
        next();

        auto inMsg = std::make_unique<Req>();
        EXPECT_TRUE(parseFromByteBuffer(&inBuffer, inMsg.get()));

        auto responseRequestPair =
            tryFindMethodRequestResponsePair(method, *inMsg);
        if (!responseRequestPair.has_value()) {
            finish(grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
                                "Invalid argument"));
            return;
        }

        auto response = responseRequestPair.value().get().response.get();
        writeAndFinish(*response);
    }

    [[nodiscard]]
    static inline bool compareMessages(
        const google::protobuf::Message &expectedRequest,
        const google::protobuf::Message &actualRequest,
        google::protobuf::util::MessageDifferencer::RepeatedFieldComparison
            compare = google::protobuf::util::MessageDifferencer::
                RepeatedFieldComparison::AS_LIST)
    {
        std::string diff;
        google::protobuf::util::MessageDifferencer differencer;
        differencer.set_repeated_field_comparison(compare);
        differencer.ReportDifferencesToString(&diff);

        return differencer.Compare(expectedRequest, actualRequest);
    }

    [[nodiscard]]
    std::optional<std::reference_wrapper<const RequestResponsePair>>
    tryFindMethodRequestResponsePair(const std::string &method,
                                     const google::protobuf::Message &msg)
    {
        if (!d_mockMethodRegistry.contains(method)) {
            return std::nullopt;
        }

        for (const auto &requestResponsePair : d_mockMethodRegistry[method]) {
            auto expectedRequest = requestResponsePair.request.get();

            if (compareMessages(*expectedRequest, msg)) {
                return std::cref(requestResponsePair);
            }
        }

        return std::nullopt;
    }
};

class GrpcGenericTestServer : public GrpcTestServer {
  public:
    GrpcGenericTestServer() : GrpcTestServer()
    {
        d_ctx = std::make_unique<GrpcTestServerGenericContext>(this);
    }

    void startAsyncServer()
    {
        d_serverThread = std::jthread([this](std::stop_token stopToken) {
            while (!stopToken.stop_requested()) {
                d_ctx->poll_one();
            }
        });
    }

    void stopServer()
    {
        d_ctx->stopPolling();
        d_serverThread.request_stop();

        d_server->Shutdown();
        d_cq->Shutdown();

        if (d_serverThread.joinable()) {
            d_serverThread.join();
        }

        void *ignored_tag = nullptr;
        bool ignored_ok = false;
        while (d_cq->Next(&ignored_tag, &ignored_ok)) {
        }
    }

    std::unique_ptr<GrpcTestServerGenericContext> d_ctx;
    std::unique_ptr<grpc::GenericServerAsyncReaderWriter> d_stream;
    std::jthread d_serverThread;
};

} // namespace buildboxcommon

#endif
