4
\$\begingroup\$

I'm having fun with shared pointers in a splay tree. Please let me know what I can do about code readability, and if you have any tips on speeding up my code, let know about that too (however, I am aware of unique_ptr and will implement that next; this is for practice).

*edit: I removed keys from the node class (remnant from another project), since it wasn't being used.

main.cpp

#include <iostream>
#include "splay.h"

int main(){
  auto i = { 7, 8, 19, 10, 15, 9, 14, 12, 11, 23};

  splay<int>* spl = new splay<int>;
  for (auto const& e : i)
    spl->insert(e);

  std::cout<<"Splay Tree\n"<<spl->print()<<std::endl;
  auto number = 10;
  if (spl->find(number))
    number = 14;
  else 
    std::cout<<"not found\n";

  std::cout<<"removing "<<number<<std::endl;
  spl->remove(number);
  std::cout<<spl->print()<<std::endl;

return 0;
}

splay.h

#ifndef SPLAY_H
#define SPLAY_H
#include <memory>
template <class V>
class node{
  public:
    node(V const& value):
          value_ (value) 
        { left = right =  nullptr; }
    std::shared_ptr< node<V> > left, right;
    std::weak_ptr< node<V> > parent;
    V value()
        {return value_;}
  private:
    V value_;
};


template <class T>
class splay
{
  public:
    typedef node<T> node;
    splay()
        { null_=nullptr; };

    void insert (const T& value) 
        { insertNode_ (head_, null_, value); };
    void remove (const T& value)
        { markNode_ (head_, value); };
    bool find (const T& value)
        { return searchNode_ (head_, value); };
    std::string print ()
        { return print_inOrder_ (head_, ""); };

  private:
    std::shared_ptr<node> head_; 
    std::shared_ptr<node> null_;

    void insertNode_ (std::shared_ptr<node>& current,  //recursive insert
        const std::shared_ptr<node>& parent,const T& value)
    {
      if (current == nullptr){ 
        current = std::make_shared<node> (value);
        current->parent = parent;
        if (current->parent.lock()!= null_)
      if (current != head_) 
         splayNode(current);
      }
      else if (current->value() < value)
        insertNode_ (current->right, current, value);
      else if (current->value() > value)
        insertNode_ (current->left , current, value);
    }

    void markNode_ (std::shared_ptr<node>& current,const T& value)
    {//recursively finds the  value, calls markRemove to remove the node
      if (current != nullptr) {
        if (current->value() < value)
          markNode_ (current->right, value);
        else if (current->value() > value)
          markNode_ (current->left , value);
        else 
          removeNode_ (current);
      }
    }

    void removeNode_ (std::shared_ptr<node>& current) //deletes node
    {
      if (current->right == nullptr) {
        if (current->left != nullptr)
          current->left->parent = current->parent;
        current = std::move(current->left);
      }
      else if (current->left == nullptr)  {
        current->left->parent = current->parent;
        current = std::move(current->right);
      }
      else  {                                       //worst case scenario
        std::shared_ptr<node> temp = current->right;
        while (temp->left != nullptr)
          temp = temp->left;
        temp->left = current->left;
        if (current->left != nullptr) {
          temp->left = current->left;
          temp->left->parent = temp;
          if (temp->right != nullptr)
            temp->right->parent= temp->parent;
          temp->parent.lock()->left = temp->right;
          temp->right=current->right;
          temp->right->parent=temp;
          if (current->parent.lock() == null_)
            head_ = temp;
          else if (current == current->parent.lock()->left)
            current->parent.lock()->left = temp;
          else
            current->parent.lock()->right= temp;
          current= std::move(temp);
        }
      }   
    }

    bool searchNode_ (const std::shared_ptr<node>& current, const T& value) 
    { //searches for node.
      if (current != nullptr) {
        if (current->value() < value)
          return searchNode_ (current->right, value);
        else if (current->value() > value)
          return searchNode_ (current->left , value);
        else  {
          splayNode(current);
          return true;
        }
      }
      return false;
    }

    void splayNode (const std::shared_ptr<node> current) {
      while (current != head_){
        if (head_ == current->parent.lock())  {
          if (head_->left == current)
            RR(current);
          else
            LR(current);
        }
        else  {
          if (current->parent.lock()->parent.lock()->left == 
              current->parent.lock())  {
            if (current->parent.lock()->left == current)  {
              RR(current->parent.lock());
              RR(current);
            }
            else  {
              LR(current);
              RR(current);
            }
          }
          else  {
            if (current->parent.lock()->right == current)  {
              LR(current->parent.lock());
              LR(current);
            }
            else  {
              RR(current);
              LR(current);
            }
          }
        }
      }
    }

    void LR(const std::shared_ptr<node> current)
    {
      current->parent.lock()->right = current->left;
      if (current->parent.lock()->right != nullptr)
        current->parent.lock()->right->parent = current->parent.lock();
      current->left = current->parent.lock();
      current->parent = current->left->parent;
      if (current->left->parent.lock() == null_)
        head_ = current;
      else if (current->left == current->left->parent.lock()->left)
        current->left->parent.lock()->left = current;
      else
        current->left->parent.lock()->right= current;
      current->left->parent = current; 
    };

    void RR(const std::shared_ptr<node> current)
    {
      current->parent.lock()->left = current->right;
      if (current->parent.lock()->left != nullptr)
        current->parent.lock()->left->parent = current->parent.lock();
      current->right = current->parent.lock();
      current->parent = current->right->parent;
      if (current->right->parent.lock() == null_)
        head_ = current;
      else if (current->right == current->right->parent.lock()->left)
        current->right->parent.lock()->left = current;
      else 
        current->right->parent.lock()->right= current;
      current->right->parent = current;
    };

    std::string print_inOrder_ (const std::shared_ptr<node>& current, 
        std::string print)
    {
      if (current != nullptr) {
        print+= "["+ std::to_string(current->value())+ "]";
        print = print_inOrder_ (current->left, print);
        print = print_inOrder_ (current->right, print);
      }
      return print;
    }
};

#endif
\$\endgroup\$

1 Answer 1

3
\$\begingroup\$

splay.h uses std::string but does not include <string>. That's easily fixed.


I don't think the node class needs to stand alone, or even be public.

I'd write:

template <class T>
class splay
{
    struct node
    {
        std::shared_ptr<node> left = {};
        std::shared_ptr<node> right = {};
        std::weak_ptr<node> parent = {};
        T value;

        explicit node(T value)
            : value{std::move(value)}
        { }
    };

The null_ member seems to never be modified, and the same for all instances, so it could be static. Better, though, is to eliminate it altogether and just use nullptr or {}.


We have stray semicolons after each of these functions, which just looks sloppy:

    void insert (const T& value) 
        { insertNode_ (head_, nullptr, value); };
    void remove (const T& value)
        { markNode_ (head_, value); };
    bool find (const T& value)
        { return searchNode_ (head_, value); };
    std::string print ()
        { return print_inOrder_ (head_, ""); };

A more serious problem is that they require T to be copy-constructible. It's better to accept value by value and pass as a forwarding reference. It's worth writing some tests that use a move-only type for this.

And print() is a problem compared to a more conventional operator<<(std::ostream&, const splay&) because it requires the whole representation to be built up in memory before even the start of it can be used.


This code is hard to read because it's strangely indented:

        if (current->parent.lock()!= null_)
      if (current != head_) 
         splayNode(current);

The current != head test looks wasteful anyway, because splayNode() does nothing in that case anyway.

The repeated use of .lock() here (and elsewhere) makes the code hard to read. We can make the code much more readable by creating names for some of these found nodes - e.g.

        auto const parent = current->parent.lock();
        auto const left_parent = current->left->parent.lock();

insert_node just silently ignores values that compare equal to existing elements, making the tree a kind of set. That's not clear in the description, and I don't think it's desirable. I'd change the else if (current->value > value) to plain else to fix that.


This comment is unhelpful:

    {//recursively finds the  value, calls markRemove to remove the node

I couldn't find markRemove.

I would probably put the equality test before the < test, because == can be more efficient for many types.


removeNode() doesn't seem to re-splay the tree. We would normally expect the parent of the removed node to be the new head of tree


A handful of places in LR() and RR() unnecessarily copy shared pointers that are about to be overwritten. Better to move-assign, which saves having to increment and decrement the use-count.


Printing only works for types that are acceptable arguments to std::to_string(). We could support user-defined types if we enable argument-dependent-lookup. But it's better to switch to streaming with << - we can use std::ostringstream if we really need to make a string.


Having taken proper care with pointers in the splay tree, it seems ironic to see main() owning a raw pointer which it fails to delete:

  splay<int>* spl = new splay<int>;

There's no need for a pointer here; simply use a local object:

  splay<int> spl;

I think the interface could use more const. If we look up an element, that doesn't change the externally-visible state of the object (even though it may have rotated nodes internally), so I would make the head pointer mutable to permit this.


Modified code

I've renamed a few things for clarity and/or consistency with standard containers. There's still a lot of work that can be done (iterators would be a worthwhile start, for example - the good news is that those can simply wrap a node*, as the only invalidation would of an iterator whose node is removed).

It would be nice to have a constructor that accepts an initializer-list, so we don't have to add values one at a time.

#include <cassert>
#include <memory>
#include <ostream>

template <class T>
class splay
{
    struct node
    {
        std::shared_ptr<node> left = {};
        std::shared_ptr<node> right = {};
        std::weak_ptr<node> parent = {};
        T value;

        explicit node(T value)
            : value{std::move(value)}
        { }
    };

    static void print_node(std::ostream& os, std::shared_ptr<node> const& n)
    {
        if (!n) {
            return;
        }
        print_node(os, n->left);
        os << n->value << ',';
        print_node(os, n->right);
    }

    mutable std::shared_ptr<node> head = {};

public:
    void insert(T value)
    {
        insertNode(head, nullptr, std::move(value));
    }
    void remove(const T& value)
    {
        markNode(head, value);
    }
    bool contains(const T& value) const
    {
        return searchNode(head, value);
    }

    friend auto& operator<<(std::ostream& os, const splay& tree)
    {
        os << '[';
        print_node(os, tree.head);
        return os << ']';
    }

private:
    void check_invariant() const
    {
#ifndef NDEBUG
        check_invariant(nullptr, head.get());
#endif
    }

    void check_invariant(const node *parent, const node *p) const
    {
        if (!p) {
            return;
        }
        assert(p->parent.lock().get() == parent);
        if (p->left) {
            assert(p->left->value < p->value);
            check_invariant(p, p->left.get());
        }
        if (p->right) {
            assert(p->value < p->right->value);
            check_invariant(p, p->right.get());
        }
    }

    void insertNode(std::shared_ptr<node>& current,
                    const std::shared_ptr<node>& parent,
                    T value)
    {
        if (!current) {
            current = std::make_shared<node>(std::move(value));
            current->parent = parent;
            if (parent) {
                splayNode(current);
            }
            return;
        }
        insertNode(current->value < value ? current->right : current->left, current, std::move(value));
    }

    void markNode(std::shared_ptr<node>& current, const T& value)
    {
        if (!current) {
            return;
        }
        if (current->value == value) {
            removeNode(current);
            return;
        }
        markNode(current->value < value ? current->right : current->left, value);
    }

    void removeNode(std::shared_ptr<node>& current)
    {
        if (current->right && current->left) {
            std::shared_ptr<node> temp = current->right;
            while (temp->left) {
                temp = temp->left;
            }
            temp->left = current->left;
            if (current->left) {
                temp->left = current->left;
                temp->left->parent = temp;
                if (temp->right) {
                    temp->right->parent = temp->parent;
                }
                temp->parent.lock()->left = temp->right;
                temp->right=current->right;
                temp->right->parent=temp;
                auto const parent = current->parent.lock();
                if (!parent)
                    head = temp;
                else if (current == parent->left)
                    parent->left = temp;
                else
                    parent->right= temp;
                current= std::move(temp);
            }
            return;
        }

        if (current->left) {
            current->left->parent = current->parent;
        } else if (current->right) {
            current->left->parent = current->parent;
            current = std::move(current->right);
        } else {
            current = nullptr;
        }
    }

    bool searchNode(const std::shared_ptr<node>& current, const T& value) const
    {
        if (!current) {
            return false;
        }
        if (current->value == value) {
            splayNode(current);
            return true;
        }
        return searchNode(current->value < value ? current->right : current->left, value);
    }

    void splayNode(const std::shared_ptr<node> current) const
    {
        while (true) {
            //check_invariant();  // useful when debugging
            auto const parent = current->parent.lock();
            if (!parent) {
                return;
            }
            auto const grandparent = parent->parent.lock();
            if (!grandparent) {
                if (parent->left == current) {
                    rotate_right(current);
                } else {
                    rotate_left(current);
                }
            } else {
                if (grandparent->left == parent) {
                    if (parent->left == current) {
                        rotate_right(parent);
                    } else {
                        rotate_left(current);
                    }
                    rotate_right(current);
                } else {
                    if (parent->right == current) {
                        rotate_left(parent);
                    } else {
                        rotate_right(current);
                    }
                    rotate_left(current);
                }
            }
        }
    }

    void rotate_left(const std::shared_ptr<node> current) const
    {
        auto const parent = current->parent.lock();
        auto const grandparent = parent->parent.lock();
        parent->right = std::move(current->left);
        if (parent->right) {
            parent->right->parent = parent;
        }
        current->left = parent;
        current->parent = grandparent;
        auto& dest = !grandparent ? head
            : current->left == grandparent->left ? grandparent->left
            : grandparent->right;
        current->left->parent = dest = std::move(current);
    }

    void rotate_right(const std::shared_ptr<node> current) const
    {
        auto const parent = current->parent.lock();
        auto const grandparent = parent->parent.lock();
        parent->left = std::move(current->right);
        if (parent->left) {
            parent->left->parent = parent;
        }
        current->right = std::move(parent);
        current->parent = grandparent;
        auto& dest = !grandparent ? head
            : current->right == grandparent->left ? grandparent->left
            : grandparent->right;
        current->right->parent = dest = std::move(current);
    }
};

I think it becomes easier to understand if we make the node object responsible for maintaining the invariant that the node's parent always points back to its owner. And also, extract out some of the repeated operations, such as finding where to reparent to. And during development, it's helpful to have a test of this invariant and the ordering invariant.


#include <cassert>
#include <memory>
#include <ostream>

template <class T>
class splay
{
    class node : public std::enable_shared_from_this<node>
    {
        using std::enable_shared_from_this<node>::shared_from_this;

        std::shared_ptr<node> l = {};
        std::shared_ptr<node> r = {};
        std::weak_ptr<node> p = {};
    public:
        T value;

        explicit node(T value)
            : value{std::move(value)}
        { }

        void check_invariant() const
        {
            if (l) {
                if (l->parent().get() != this) {
                }
                assert(l->parent().get() == this);
                assert(l->value < value);
                l->check_invariant();
            }
            if (r) {
                if (r->parent().get() != this) {
                }
                assert(r->parent().get() == this);
                assert(value < r->value);
                r->check_invariant();
            }
        }

        const std::shared_ptr<node>& left() const
        {
            return l;
        }

        const std::shared_ptr<node>& right() const
        {
            return r;
        }

        std::shared_ptr<node> parent() const
        {
            return p.lock();
        }

        bool is_left_child() const
        {
            auto p = parent();
            return p && p->l.get() == this;
        }

        void orphan()
        {
            p.reset();
        }

        void set_left(std::shared_ptr<node> n)
        {
            if (n) {
                n->p = shared_from_this();
            }
            l = std::move(n);
        }

        void set_right(std::shared_ptr<node> n)
        {
            if (n) {
                n->p = shared_from_this();
            }
            r = std::move(n);
        }
    };

    static void print_node(std::ostream& os, std::shared_ptr<node> const& n)
    {
        if (!n) {
            return;
        }
        print_node(os, n->left());
        os << n->value << ',';
        print_node(os, n->right());
    }

    mutable std::shared_ptr<node> head = {};

public:
    void insert(T value)
    {
        auto n = std::make_shared<node>(std::move(value));
        if (head) {
            insertNode(head, n);
            splayNode(n);
        } else {
            head = std::move(n);
        }
    }
    void remove(const T& value)
    {
        auto node = searchNode(head, value);
        removeNode(node);
    }
    bool contains(const T& value) const
    {
        auto node = searchNode(head, value);
        splayNode(node);
        return node != nullptr;
    }

    friend auto& operator<<(std::ostream& os, const splay& tree)
    {
        os << '[';
        print_node(os, tree.head);
        return os << ']';
    }

private:
    void check_invariant() const
    {
#ifndef NDEBUG
        if (head) {
            head->check_invariant();
        }
#endif
    }

    void insertNode(const std::shared_ptr<node>& current, std::shared_ptr<node> n)
    {
        if (n->value < current->value) {
            if (current->left()) {
                insertNode(current->left(), std::move(n));
            } else {
                current->set_left(std::move(n));
            }
        } else {
            if (current->right()) {
                insertNode(current->right(), std::move(n));
            } else {
                current->set_right(std::move(n));
            }
        }
    }

    void removeNode(std::shared_ptr<node>& current)
    {
        if (!current) {
            return;
        }

        auto const parent = current->parent();
        if (current->right() && current->left()) {
            std::shared_ptr<node> temp = current->right();
            while (temp->left()) {
                temp = temp->left();
            }
            temp->set_left(current->left());
            replace_node(current, current->right());
        } else if (current->left()) {
            replace_node(current, current->left());
        } else if (current->right()) {
            replace_node(current, current->right());
        } else {
            replace_node(current, nullptr);
        }

        splayNode(parent);
    }

    std::shared_ptr<node> searchNode(const std::shared_ptr<node>& current, const T& value) const
    {
        if (!current) {
            return current;
        }
        if (current->value == value) {
            return current;
        }
        return searchNode(current->value < value ? current->right() : current->left(), value);
    }

    void splayNode(const std::shared_ptr<node> current) const
    {
        return;
        if (!current) {
            return;
        }

        while (true) {
            auto const parent = current->parent();
            if (!parent) {
                assert(head == current);
                return;
            }
            auto const grandparent = parent->parent();
            if (!grandparent) {
                if (current->is_left_child()) {
                    rotate_right(current);
                } else {
                    rotate_left(current);
                }
            } else {
                if (parent->is_left_child()) {
                    if (current->is_left_child()) {
                        rotate_right(parent);
                    } else {
                        rotate_left(current);
                    }
                    rotate_right(current);
                } else {        // parent is right child
                    if (current->is_left_child()) {
                        rotate_right(current);
                    } else {
                        rotate_left(parent);
                    }
                    rotate_left(current);
                }
            }
        }
    }

    void rotate_left(const std::shared_ptr<node> current) const
    {
        auto const parent = current->parent();
        auto const grandparent = parent->parent();
        parent->set_right(current->left());
        replace_node(parent, current);
        current->set_left(std::move(parent));
    }

    void rotate_right(const std::shared_ptr<node> current) const
    {
        auto const parent = current->parent();
        auto const grandparent = parent->parent();
        parent->set_left(current->right());
        replace_node(parent, current);
        current->set_right(std::move(parent));
    }

    void replace_node(const std::shared_ptr<node>& current,
                      std::shared_ptr<node> replacement) const
    {
        auto const parent = current->parent();
        if (!parent) {
            set_head(std::move(replacement));
        } else if (current->is_left_child()) {
            parent->set_left(std::move(replacement));
        } else {
            parent->set_right(std::move(replacement));
        }
    }

    void set_head(std::shared_ptr<node> n) const
    {
        if (n) {
            n->orphan();
        }
        head = std::move(n);
    }
};

Although this is somewhat longer overall, I think it makes the tree-manipulation algorithm easier to read.

\$\endgroup\$

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.