Shared pointer with predefined maximum use count


Context: I have a queue that supports a single reader and a single writer, which may run on two different threads. To enforce this, I need to limit the number of threads owning the queue at any time to 2 (the writer already owns the queue). I was thinking of creating a shared_ptr to the queue whose maximum ref count is known at compile time and set to 2. Hence my question below.

Question: Is there a way to implement a shared pointer (maybe using std::unique_ptr) with a maximum ref count that is known at compile time? In my case max_ref_count = 2, i.e. exceeding the ref count limit of 2 should be a compile-time error.

const auto p    = std::make_shared<2, double>(2159); //works fine
const auto q    = p; // works fine
const auto err1 = p; // does not compile
const auto err2 = q; // does not compile

There are 2 answers

Jan Schultke On BEST ANSWER

Impossible at compile-time

What you're doing isn't possible at compile-time, only at runtime. Note that std::shared_ptr is copyable, so it's possible that "copy paths diverge":

std::shared_ptr A = /* ... */;
auto B = A;
auto C = A;

At this point, B and C don't know anything about each other, and there is no way to make them aware of each other through the type system. At best, you can count the length of the copy path from A to B and on to subsequent copies, but not the total number of copies.

This would require some form of stateful metaprogramming and C++ does not support that.
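
To make the "copy path" limitation concrete, here is a minimal sketch. The type path_counted_ptr and the function owners_after_diverging_copies are hypothetical names invented for this illustration. The length of the copy path is encoded in the type, so a path that is too long fails to compile, yet two copies taken from the same root are both accepted. It is written to compile as C++17.

```cpp
#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical illustration type, not part of the standard library.
// Depth is the length of the copy path from the root to this object.
template <typename T, std::size_t Depth, std::size_t MaxDepth>
class path_counted_ptr {
    static_assert(Depth <= MaxDepth, "copy path is longer than MaxDepth");

    std::shared_ptr<T> base_;

    // Every depth needs access to the private constructor of the next one.
    template <typename, std::size_t, std::size_t>
    friend class path_counted_ptr;

    explicit path_counted_ptr(std::shared_ptr<T> base) : base_(std::move(base)) {}

public:
    // Only the root of a copy path may be created from a raw pointer.
    explicit path_counted_ptr(T* ptr) : base_(ptr) {
        static_assert(Depth == 0, "only the root may adopt a raw pointer");
    }

    // Plain copies are disabled; share() is the only way to add an owner.
    path_counted_ptr(const path_counted_ptr&) = delete;
    path_counted_ptr& operator=(const path_counted_ptr&) = delete;
    path_counted_ptr(path_counted_ptr&&) = default;
    path_counted_ptr& operator=(path_counted_ptr&&) = default;

    // The returned type records a path that is one step longer.
    path_counted_ptr<T, Depth + 1, MaxDepth> share() const {
        return path_counted_ptr<T, Depth + 1, MaxDepth>(base_);
    }

    long use_count() const { return base_.use_count(); }
};

// Two copies taken from the same root: both have depth 1, so both compile.
inline long owners_after_diverging_copies() {
    path_counted_ptr<double, 0, 1> a(new double(2159));
    auto b = a.share();
    auto c = a.share();  // accepted: c does not know that b exists
    static_assert(std::is_same<decltype(b), decltype(c)>::value,
                  "b and c have the same type");
    // b.share() would be rejected here: its depth would be 2 > MaxDepth.
    return a.use_count();  // counts a, b and c
}
```

The type system rejects b.share(), but it cannot see that a.share() was called twice, so the total number of owners ends up above the intended limit.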

Difficult at run-time

Note two problems:

  1. std::shared_ptr is also expected to have thread-safe counting. This also means that std::shared_ptr::use_count() may not yield the most recent result and cannot be considered a reliable metric.
  2. If you kept track of the use count yourself, you would need to allocate a new int in addition to the atomic counter that shared pointers already have. This is somewhat annoying, but doable.

Single-threaded solution

If you don't consider the problems of multi-threading, you can just use std::shared_ptr::use_count() to keep track of uses. This is a reliable metric in a single-threaded program. You just need to make a wrapper for std::shared_ptr which throws whenever the limit would be exceeded, which can happen in the copy constructor and the copy assignment operator.

#include <cstddef>
#include <memory>

template <typename T, std::size_t N>
  requires (N != 0)
struct limited_shared_ptr {
    using base_type = std::shared_ptr<T>;
    
  private:
    base_type base;

  public:
    // OK
    limited_shared_ptr(T* ptr) : base{ptr} {}

    // The following contains an immediately invoked lambda expression in the copy
    // constructor.
    // This is clunky, but necessary so that we throw *before* we copy-construct
    limited_shared_ptr(const limited_shared_ptr& other)
      : base{[&other]() -> const base_type& {
        // use_count() returns long, so cast before comparing against N
        if (static_cast<std::size_t>(other.base.use_count()) >= N) {
            throw /* ... */;
        }
        return other.base;
    }()} {}

    limited_shared_ptr& operator=(const limited_shared_ptr& other) {
        // Only throw if the assignment would actually add an owner.
        if (this != &other &&
            base.get() != other.base.get() &&
            other.base.get() != nullptr &&
            static_cast<std::size_t>(other.base.use_count()) >= N) {
            throw /* ... */;
        }
        base = other.base;
        return *this;
    }

    // Move assignment or destruction cannot increase the use_count(), only
    // keep it constant or decrease it, so we can keep these operations defaulted.
    limited_shared_ptr(limited_shared_ptr&& other) = default;
    limited_shared_ptr& operator=(limited_shared_ptr&&) = default;
    ~limited_shared_ptr() = default;

    // TODO: expose more of the std::shared_ptr interface
};
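
As a usage-oriented variant of the wrapper above, the following sketch replaces the immediately invoked lambda with a private static helper and shows the limit being hit. The names limited_shared_ptr_sketch, checked and copy_is_rejected are made up for this example, and std::length_error is my own choice of exception, since the answer leaves the thrown type open. It is single-threaded only and written to compile as C++17.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>

template <typename T, std::size_t N>
class limited_shared_ptr_sketch {
    static_assert(N != 0, "the limit must allow at least one owner");

    std::shared_ptr<T> base_;

    // Validates *before* the base is copied, so a rejected copy never
    // touches the use count. Assumed exception type: std::length_error.
    static const std::shared_ptr<T>& checked(const std::shared_ptr<T>& source) {
        if (static_cast<std::size_t>(source.use_count()) >= N) {
            throw std::length_error("limited_shared_ptr_sketch: owner limit exceeded");
        }
        return source;
    }

public:
    explicit limited_shared_ptr_sketch(T* ptr) : base_(ptr) {}

    limited_shared_ptr_sketch(const limited_shared_ptr_sketch& other)
        : base_(checked(other.base_)) {}

    limited_shared_ptr_sketch& operator=(const limited_shared_ptr_sketch& other) {
        // If both already point at the same object, no owner is added.
        if (base_ != other.base_) {
            base_ = checked(other.base_);
        }
        return *this;
    }

    limited_shared_ptr_sketch(limited_shared_ptr_sketch&&) = default;
    limited_shared_ptr_sketch& operator=(limited_shared_ptr_sketch&&) = default;

    T& operator*() const { return *base_; }
    long use_count() const { return base_.use_count(); }
};

// Returns true if copying `source` is refused with std::length_error.
template <typename P>
bool copy_is_rejected(const P& source) {
    try {
        P extra = source;
        (void)extra;
        return false;
    } catch (const std::length_error&) {
        return true;
    }
}
```

With N = 2, one copy of the original succeeds, and any further copy of either owner is rejected while the use count stays at 2.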

Multi-threaded solution

This is a little bit more complicated, and I will only provide the general outline.

You can create a std::shared_ptr with a custom deleter. This custom deleter can contain a std::atomic<std::size_t> to keep reliable track of the use count.

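struct counting_deleter {
    std::atomic<std::size_t> use_count{1}; // the initial owner counts as one
    counting_deleter() = default;
    // std::atomic is not copyable, but std::shared_ptr must be able to copy or
    // move its deleter, so the copy constructor copies the current value.
    counting_deleter(const counting_deleter& other)
      : use_count{other.use_count.load()} {}
    void operator()(auto* ptr) const {
        delete ptr;
    }
};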

With this, you don't need to manage resources yourself, but can let std::shared_ptr do it and access the deleter with std::get_deleter<counting_deleter>(base) at any time.

In the copy constructor, similar to the single-threaded solution, you would check:

counting_deleter* deleter = std::get_deleter<counting_deleter>(other.base);
std::size_t expected = deleter->use_count.load(std::memory_order::relaxed);
do {
    if (expected >= N) {
        // increasing the use_count would exceed the limit
        throw /* ... */;
    }
    // otherwise, attempt to increment atomically using a weak CAS
} while (!deleter->use_count.compare_exchange_weak(expected, expected + 1));
// note: the CAS can probably use std::memory_order::relaxed too

Basically, we check if increasing the use count is possible. If not, we throw, otherwise we attempt to increment the use_count with compare_exchange_weak.

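The copy assignment operator is analogous. Moving or destroying cannot blow the limit, but note that std::shared_ptr does not maintain this custom counter for you: whenever a non-empty pointer gives up its reference (in the destructor, or when an assignment replaces the old target), the counter has to be decremented by hand. Only the move constructor can simply be defaulted.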
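
The outline above could be assembled roughly as follows. This is a sketch under my own assumptions, not a vetted implementation: atomic_limited_ptr, owners and copy_is_rejected are hypothetical names, std::length_error is an arbitrary choice of exception, and relaxed ordering is used throughout on the assumption that the counter only enforces the limit and does not publish any data. It is written to compile as C++17.

```cpp
#include <atomic>
#include <cstddef>
#include <memory>
#include <stdexcept>

// Deleter that carries the owner count inside the shared_ptr control block.
struct counting_deleter {
    std::atomic<std::size_t> use_count{1};  // the first owner counts as one
    counting_deleter() = default;
    // std::atomic cannot be copied, so copy its current value by hand.
    counting_deleter(const counting_deleter& other)
        : use_count(other.use_count.load()) {}
    template <typename U>
    void operator()(U* ptr) const { delete ptr; }
};

template <typename T, std::size_t N>
class atomic_limited_ptr {
    static_assert(N != 0, "the limit must allow at least one owner");
    std::shared_ptr<T> base_;

    static std::atomic<std::size_t>* counter(const std::shared_ptr<T>& p) {
        counting_deleter* d = std::get_deleter<counting_deleter>(p);
        return d ? &d->use_count : nullptr;  // null for an empty pointer
    }

    // Reserves one owner slot with a CAS loop, or throws if none is left.
    static const std::shared_ptr<T>& acquire(const std::shared_ptr<T>& source) {
        if (std::atomic<std::size_t>* count = counter(source)) {
            std::size_t expected = count->load(std::memory_order_relaxed);
            do {
                if (expected >= N) {
                    throw std::length_error("atomic_limited_ptr: owner limit exceeded");
                }
            } while (!count->compare_exchange_weak(expected, expected + 1,
                                                   std::memory_order_relaxed));
        }
        return source;
    }

public:
    explicit atomic_limited_ptr(T* ptr) : base_(ptr, counting_deleter{}) {}
    atomic_limited_ptr(const atomic_limited_ptr& other) : base_(acquire(other.base_)) {}
    atomic_limited_ptr(atomic_limited_ptr&&) noexcept = default;

    // Copy-and-swap: the parameter is already a checked copy (or a moved value).
    atomic_limited_ptr& operator=(atomic_limited_ptr other) noexcept {
        base_.swap(other.base_);
        return *this;
    }

    // The custom counter is not maintained by std::shared_ptr, so it is
    // decremented by hand when a non-empty owner goes away.
    ~atomic_limited_ptr() {
        if (std::atomic<std::size_t>* count = counter(base_)) {
            count->fetch_sub(1, std::memory_order_relaxed);
        }
    }

    std::size_t owners() const {
        std::atomic<std::size_t>* count = counter(base_);
        return count ? count->load(std::memory_order_relaxed) : 0;
    }
};

// Returns true if copying `source` is refused with std::length_error.
template <typename P>
bool copy_is_rejected(const P& source) {
    try {
        P extra = source;
        (void)extra;
        return false;
    } catch (const std::length_error&) {
        return true;
    }
}
```

One simplification to be aware of: because assignment takes its argument by value, the new reference is acquired before the old one is released, so an assignment can be rejected even when the net number of owners would not grow. As with std::shared_ptr itself, each thread should copy from its own instance rather than from one that another thread may be destroying.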

DeerSpotter On

The standard std::shared_ptr in C++ doesn't provide a mechanism to set a maximum reference count at compile time. However, you can get similar behavior at run time by writing a custom smart pointer that enforces a limit on the reference count. Here's a basic, single-threaded example:

#include <iostream>
#include <stdexcept>

template <typename T, int MaxRefCount = 2>
class LimitedSharedPtr {
    static_assert(MaxRefCount > 0, "MaxRefCount must be positive");

public:
    // Takes ownership of ptr and starts a new count shared by all copies.
    LimitedSharedPtr(T* ptr) : ptr_(ptr), refCount_(new int(1)) {}

    LimitedSharedPtr(const LimitedSharedPtr& other)
        : ptr_(other.ptr_), refCount_(other.refCount_) {
        // Check before incrementing so a rejected copy leaves the count untouched.
        if (*refCount_ >= MaxRefCount) {
            throw std::runtime_error("Exceeded maximum reference count");
        }
        ++*refCount_;
    }

    LimitedSharedPtr& operator=(const LimitedSharedPtr& other) {
        // Nothing to do if both already share the same count.
        if (this != &other && refCount_ != other.refCount_) {
            if (*other.refCount_ >= MaxRefCount) {
                throw std::runtime_error("Exceeded maximum reference count");
            }
            release();
            ptr_ = other.ptr_;
            refCount_ = other.refCount_;
            ++*refCount_;
        }
        return *this;
    }

    ~LimitedSharedPtr() {
        release();
    }

    T* get() const {
        return ptr_;
    }

private:
    // Drops this owner's reference; the last owner frees the object and the count.
    void release() {
        if (--*refCount_ == 0) {
            delete ptr_;
            delete refCount_;
        }
    }

    T* ptr_;
    int* refCount_; // shared by all copies; not thread-safe
};

int main() {
    LimitedSharedPtr<double, 2> p(new double(2159));
    LimitedSharedPtr<double, 2> q = p; // works fine
    try {
        LimitedSharedPtr<double, 2> err1 = p; // throws runtime_error
    } catch (const std::runtime_error& e) {
        std::cerr << e.what() << '\n';
    }
    try {
        LimitedSharedPtr<double, 2> err2 = q; // throws runtime_error
    } catch (const std::runtime_error& e) {
        std::cerr << e.what() << '\n';
    }

    return 0;
}