The visitor pattern: OOP takes on the Expression Problem

The visitor pattern is a behavioral design pattern popularized by the "Gang of Four" book, Design Patterns: Elements of Reusable Object-Oriented Software. It is used to separate operations on a set of types from the types themselves. When using the composite pattern to build an AST from a class hierarchy, as one would when using the interpreter pattern, every new operation to be performed on the AST requires adding another method to every class in the hierarchy. To say this is tedious would be a monumental understatement, and the underlying tension is a variant of what's known as the expression problem.

Where the interpreter pattern would have us declare a print() method in every class to print the AST, an eval() method in every expression node to evaluate expressions, and so on, the visitor pattern addresses this issue by declaring an abstract "visitor" class for our AST type, which we derive from to create new operations on the structure. Instead of adding new methods to existing classes, we create new types of visitors, such as a PrettyPrintVisitor or an InterpreterVisitor.
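For contrast, here is a minimal sketch of the interpreter-pattern approach just described. It uses two hypothetical node types, Num and Add, that are not part of this post's AST. Each operation is a virtual method that every node class must implement, so adding a third operation would mean editing every class again.

```cpp
#include <memory>
#include <string>
#include <utility>

// Interpreter-pattern style: each operation lives inside every node class.
// Num and Add are hypothetical, minimal node types for illustration only.
class Node {
    public:
        virtual ~Node() = default;
        virtual std::string print() const = 0;  // operation #1
        virtual int eval() const = 0;           // operation #2
        // A new operation (e.g. type checking) means a new method here
        // AND an override in every subclass below.
};

class Num : public Node {
    private:
        int value;
    public:
        explicit Num(int v) : value(v) { }
        std::string print() const override { return std::to_string(value); }
        int eval() const override { return value; }
};

class Add : public Node {
    private:
        std::unique_ptr<Node> left, right;
    public:
        Add(std::unique_ptr<Node> l, std::unique_ptr<Node> r)
            : left(std::move(l)), right(std::move(r)) { }
        std::string print() const override {
            return "(" + left->print() + " + " + right->print() + ")";
        }
        int eval() const override { return left->eval() + right->eval(); }
};
```

If you build 1 + (2 + 3) from these classes, print() yields "(1 + (2 + 3))" and eval() yields 6. Both operations are scattered across every class, which is exactly what the visitor pattern avoids.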

In today's post I want to discuss implementing the visitor pattern for a simple AST. This is one of those topics where it's easier to show what I'm talking about, so I'll skip the small talk and get right into it.

Grammar, ASTs, and Class Hierarchies

In order to keep things interesting while still being manageable in a single post, we will construct an AST for a toy language that supports a few simple constructs like looping, if statements, and print statements, as well as assignment and basic arithmetic operators for expressions.

The following EBNF-like grammar describes the language we'll be implementing. Quoted symbols and the keywords while, if, else, and println are literal tokens; [ ] marks an optional part, and { }* marks zero or more repetitions.

    <Program>       := <StatementList>
    <StatementList> := <Statement> { <Statement> }*
    <Statement>     := [ <WhileStmt> | <IfStmt> | <PrintStmt> | <ExprStmt> ] ';'
    <WhileStmt>     := while '(' <expression> ')' '{' <StatementList> '}'
    <IfStmt>        := if '(' <expression> ')' '{' <StatementList> '}' else '{' <StatementList> '}'
    <PrintStmt>     := println <expression>
    <ExprStmt>      := <expression>
    <expression>    := <relop> [ ':=' <relop> ]
    <relop>         := <term> [ ( '==' | '!=' | '<' | '>' | '<=' | '>=' ) <term> ]
    <term>          := <factor> { ( '+' | '-' ) <factor> }*
    <factor>        := <val> { ( '*' | '/' ) <val> }*
    <val>           := [ '-' ] <primary>
    <primary>       := number | id | string | '(' <expression> ')'

As you can see, the grammar mixes expressions and statements, and even a language this small will need a not-insignificant number of node types in its AST's class hierarchy.

Node Types

When implementing our node structure, each node type will declare an accept() method that serves as the "gateway" into the structure. We will use the composite pattern to design a class hierarchy for our node types, from which we construct the AST.

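#include <iostream>
#include <list>
#include <string>
using namespace std;

// Token is produced by the lexer (not shown). The code in this post assumes it
// has a string member 'lexeme' and a token-type member 'type'.

class Visitor;  // forward declaration: accept() only needs a pointer to it

//Base AST Class
class ASTNode {
    private:
        Token token;
    public:
        ASTNode(Token tok) : token(tok) { }
        virtual ~ASTNode() = default;
        Token getToken() {
            return token;
        }
        // The "gateway" into a node: dispatches to the visitor's matching visit()
        virtual void accept(Visitor* visitor) = 0;
};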

//Base Expr Class
class ExpressionNode : public ASTNode {
    public:
        ExpressionNode(Token tok) : ASTNode(tok) { }
};

//Base Stmt Class
class StatementNode : public ASTNode {
    public:
        StatementNode(Token tok) : ASTNode(tok) { }
};

 

The Visitor

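The visitor class is the object responsible for doing the actual traversal of the AST. The node being visited has an accept() method that takes the visitor as a parameter; likewise, the visitor has a visit() method that takes the node type as a parameter. This pairing emulates double dispatch, the simplest form of multiple dispatch: the method that finally runs depends on both the dynamic type of the node (selected by the virtual accept() call) and the dynamic type of the visitor (selected by the virtual visit() call). C++ virtual functions natively dispatch on the dynamic type of only one object.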
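The accept()/visit() handshake can be seen in isolation in this minimal, self-contained sketch. Num, Neg, and TraceVisitor are hypothetical names, not classes from this post. Calling accept() through a Node* selects the node's dynamic type. Inside accept(), `this` has the concrete static type, so the compiler picks the matching visit() overload, and that virtual call then selects the visitor's dynamic type.

```cpp
#include <string>

class Num;   // forward declarations, because the Visitor refers to them
class Neg;

class Visitor {
    public:
        virtual ~Visitor() = default;
        virtual void visit(Num* n) = 0;
        virtual void visit(Neg* n) = 0;
};

class Node {
    public:
        virtual ~Node() = default;
        virtual void accept(Visitor* visitor) = 0;  // dispatch #1: on the node's type
};

class Num : public Node {
    public:
        int value;
        explicit Num(int v) : value(v) { }
        // 'this' is a Num* here, so overload resolution picks visit(Num*);
        // the virtual call then dispatches on the visitor's type (dispatch #2).
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

class Neg : public Node {
    public:
        Node* operand;  // non-owning pointer, to keep the sketch short
        explicit Neg(Node* op) : operand(op) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

// Records which visit() overload ran for each node.
class TraceVisitor : public Visitor {
    public:
        std::string trace;
        void visit(Num* n) override { trace += "Num(" + std::to_string(n->value) + ") "; }
        void visit(Neg* n) override {
            trace += "Neg ";
            n->operand->accept(this);  // the visitor drives the traversal
        }
};
```

Given `Num one(1); Neg neg(&one);`, calling accept() on neg through a Node* leaves the visitor's trace equal to "Neg Num(1) ", even though the caller only ever held a Node* and a Visitor*.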

The base Visitor class declares only pure virtual methods, since it is an abstract class that should never be instantiated directly. It is this type that we use as the parameter of the accept() method. In order to declare our visitor class, we need forward declarations of all the node types that our visitor classes will be visiting.

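// Forward declarations: the visitor only needs to know these types exist
class StatementList;
class ProgramStatement;
class PrintStatement;
class WhileStatement;
class IfStatement;
class ExprStatement;
class IdExpression;
class LiteralExpression;
class AssignExpression;
class BinaryExpression;
class RelOpExpression;

//Abstract Visitor Interface
class Visitor {
    public:
        virtual ~Visitor() = default;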
        virtual void visit(PrintStatement* ps) = 0;
        virtual void visit(IfStatement* is) = 0;
        virtual void visit(WhileStatement* ws) = 0;
        virtual void visit(ExprStatement* es) = 0;
        virtual void visit(ProgramStatement* ps) = 0;
        virtual void visit(StatementList* sl) = 0;

        virtual void visit(IdExpression* idexpr) = 0;
        virtual void visit(LiteralExpression* litexpr) = 0;
        virtual void visit(AssignExpression* assignExpr) = 0;
        virtual void visit(BinaryExpression* binexpr) = 0;
        virtual void visit(RelOpExpression* relexpr) = 0;
};

Concrete AST Nodes

For each node type we implement a concrete class that derives from either StatementNode or ExpressionNode. We still write just as many classes as we would for the interpreter pattern, but each class is MUCH shorter: aside from its structural properties and accessors, the only method each one needs is accept().

First up, the statement nodes:

class StatementList : public StatementNode {
    private:
        list<StatementNode*> statements;
    public:
        StatementList(Token tk, list<StatementNode*> stmts) : StatementNode(tk), statements(stmts) { }
        list<StatementNode*>& getStatements() { return statements; }
        void addStatement(StatementNode* stmt) { statements.push_back(stmt); }
        void accept(Visitor* visitor) { visitor->visit(this); }
};


class ProgramStatement : public StatementNode {
    private:
        StatementList* statementList = nullptr;
    public:
        ProgramStatement(Token token) : StatementNode(token) { }
        void setProgram(StatementList* sn) { statementList = sn; }
        StatementList* getStatement() { return statementList; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};

class PrintStatement : public StatementNode {
    private:
        ExpressionNode* expression = nullptr;
    public:
        PrintStatement(Token token) : StatementNode(token) { }
        void setExpression(ExpressionNode* expr) { expression = expr; }
        ExpressionNode* getExpression() { return expression; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};

class WhileStatement : public StatementNode {
    private:
        ExpressionNode* testExpr = nullptr;
        StatementList* body = nullptr;
    public:
        WhileStatement(Token token) : StatementNode(token) { }
        void setTestExpr(ExpressionNode* expr) { testExpr = expr; }
        void setLoopBody(StatementList* stmt) { body = stmt; }
        ExpressionNode* getTestExpr() { return testExpr; }
        StatementList* getLoopBody() { return body; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};

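class IfStatement : public StatementNode {
    private:
        ExpressionNode* testExpr = nullptr;
        StatementNode* trCase = nullptr;   // body run when the test is true
        StatementNode* faCase = nullptr;   // else body
    public:
        IfStatement(Token token) : StatementNode(token) { }
        void setTestExpr(ExpressionNode* expr) { testExpr = expr; }
        void setTrueCase(StatementNode* stmt) { trCase = stmt; }
        void setFalseCase(StatementNode* stmt) { faCase = stmt; }
        ExpressionNode* getTestExpr() { return testExpr; }
        StatementNode* getTrueCase() { return trCase; }
        StatementNode* getFalseCase() { return faCase; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};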

class ExprStatement : public StatementNode {
    private:
        ExpressionNode* expr = nullptr;  // stays nullptr for an empty statement
    public:
        ExprStatement(Token token) : StatementNode(token) { }
        void setExpr(ExpressionNode* expression) { expr = expression; }
        ExpressionNode* getExpression() { return expr; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};

Unlike the StatementNodes, some of the ExpressionNodes are leaf nodes with no children. Others, such as BinaryExpression, are internal nodes with left and right children; a visitor traverses them depth first by calling each child's accept() method.

class IdExpression : public ExpressionNode {
    private:

    public:
        IdExpression(Token token) : ExpressionNode(token) { }
        string getId() { return getToken().lexeme; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};

class LiteralExpression : public ExpressionNode {
    private:

    public:
        LiteralExpression(Token token) : ExpressionNode(token) { }
        double eval(Environment& e) { return std::stod(getToken().lexeme); }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};

class BinaryExpression : public ExpressionNode {
    private:
        ExpressionNode* left = nullptr;
        ExpressionNode* right = nullptr;
    public:
        BinaryExpression(Token token) : ExpressionNode(token) { }
        void setLeft(ExpressionNode* ll) { left = ll; }
        void setRight(ExpressionNode* rr) { right = rr; }
        ExpressionNode* getLeft() { return left; }
        ExpressionNode* getRight() { return right; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};

class RelOpExpression : public ExpressionNode {
    private:
        ExpressionNode* left = nullptr;
        ExpressionNode* right = nullptr;
    public:
        RelOpExpression(Token token) : ExpressionNode(token) { }
        void setLeft(ExpressionNode* ll) { left = ll; }
        void setRight(ExpressionNode* rr) { right = rr; }
        ExpressionNode* getLeft() { return left; }
        ExpressionNode* getRight() { return right; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};

class AssignExpression : public ExpressionNode {
    private:
        IdExpression* left = nullptr;
        ExpressionNode* right = nullptr;
    public:
        AssignExpression(Token tk) : ExpressionNode(tk) { }
        void setLeft(IdExpression* expr) { left = expr; }
        void setRight(ExpressionNode* expr) { right = expr; }
        IdExpression* getLeft() { return left; }
        ExpressionNode* getRight() { return right; }
        void accept(Visitor* visitor) {
            visitor->visit(this);
        }
};
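Here is a minimal, self-contained sketch of the leaf/internal distinction. It uses hypothetical Num (leaf) and Bin (internal) nodes rather than this post's classes. HeightVisitor is likewise hypothetical. It reaches each child through accept(), so the recursion is a depth-first traversal. Because visit() returns void, the result is passed back through a member variable.

```cpp
#include <algorithm>
#include <memory>
#include <utility>

class Num;
class Bin;

class Visitor {
    public:
        virtual ~Visitor() = default;
        virtual void visit(Num* n) = 0;
        virtual void visit(Bin* b) = 0;
};

class Node {
    public:
        virtual ~Node() = default;
        virtual void accept(Visitor* visitor) = 0;
};

// Leaf node: no children to traverse.
class Num : public Node {
    public:
        int value;
        explicit Num(int v) : value(v) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

// Internal node: an operator with left and right children.
class Bin : public Node {
    public:
        char op;
        std::unique_ptr<Node> left, right;
        Bin(char o, std::unique_ptr<Node> l, std::unique_ptr<Node> r)
            : op(o), left(std::move(l)), right(std::move(r)) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

// Height of the tree: a leaf is 1, an internal node is 1 + its taller child.
class HeightVisitor : public Visitor {
    public:
        int result = 0;
        void visit(Num*) override { result = 1; }
        void visit(Bin* b) override {
            b->left->accept(this);      // descend into the left subtree first
            int leftHeight = result;
            b->right->accept(this);     // then into the right subtree
            int rightHeight = result;
            result = 1 + std::max(leftHeight, rightHeight);
        }
};
```

For the tree of 1 + (2 * 3), HeightVisitor leaves result equal to 3. For a lone Num, it leaves result equal to 1.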

Concrete Visitors

Now that our node structures are complete, we can move on to implementing the concrete visitor types. The first one we'll implement is a pretty printer, PrintVisitor, which traverses the AST while tracking the depth of the node currently being visited. At each node it displays the node's type, indented according to that node's depth.

class PrintVisitor : public Visitor {
    private:
        int d = 0;  // depth of the node currently being visited
        void enter(string s = "") {
            ++d;
            say(s);
        }
        void leave(string s = "") {
            --d;
        }
        void say(string s) {
            for (int i = 0; i < d; i++) {
                cout<<" ";
            }
            cout<<s<<endl;
        }
    public:
        void visit(StatementList* sl) override {
            for (auto m : sl->getStatements()) {
                m->accept(this);
            }
        }
        void visit(ProgramStatement* ps) override {
            enter("Program");
            ps->getStatement()->accept(this);
            leave();
        }
        void visit(PrintStatement* ps) override {
            enter("print statement");
            ps->getExpression()->accept(this);
            leave();
        }
        void visit(WhileStatement* ws) override {
            enter("While statement");
            ws->getTestExpr()->accept(this);
            ws->getLoopBody()->accept(this);
            leave();
        }
        void visit(ExprStatement* es) override {
            enter("Expr Statement");
            if (es->getExpression() != nullptr)
                es->getExpression()->accept(this);
            leave();
        }
        void visit(IdExpression* idexpr) override {
            enter("Id Expression");
            cout<<idexpr->getToken().lexeme<<endl;
            leave();
        }
        void visit(LiteralExpression* lit) override {
            enter("Literal Expression");
            cout<<lit->getToken().lexeme<<endl;
            leave();
        }
        void visit(AssignExpression* assign) override {
            enter("Assignment Expression");
            assign->getLeft()->accept(this);
            assign->getRight()->accept(this);
            leave();
        }
        void visit(BinaryExpression* bin) override {
            enter("Binary Expression");
            cout<<bin->getToken().lexeme<<endl;
            bin->getLeft()->accept(this);
            bin->getRight()->accept(this);
            leave();
        }
        void visit(RelOpExpression* rel) override {
            enter("Relop Expression");
            rel->getLeft()->accept(this);
            rel->getRight()->accept(this);
            leave();
        }
        void visit(IfStatement* is) override {
            enter("If statement");
            is->getTestExpr()->accept(this);
            is->getTrueCase()->accept(this);
            if (is->getFalseCase() != nullptr)
                is->getFalseCase()->accept(this);
            leave();
        }
}; 
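The same enter/leave idea also works in a small, self-contained sketch. It uses hypothetical Num and Bin node types, not the post's classes, and it writes into a std::ostringstream instead of cout so the output is easy to inspect. The depth counter goes up before a node's children are visited and back down afterwards, so each child is printed one space deeper than its parent.

```cpp
#include <memory>
#include <sstream>
#include <string>
#include <utility>

class Num;
class Bin;

class Visitor {
    public:
        virtual ~Visitor() = default;
        virtual void visit(Num* n) = 0;
        virtual void visit(Bin* b) = 0;
};

class Node {
    public:
        virtual ~Node() = default;
        virtual void accept(Visitor* visitor) = 0;
};

class Num : public Node {
    public:
        int value;
        explicit Num(int v) : value(v) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

class Bin : public Node {
    public:
        char op;
        std::unique_ptr<Node> left, right;
        Bin(char o, std::unique_ptr<Node> l, std::unique_ptr<Node> r)
            : op(o), left(std::move(l)), right(std::move(r)) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

// Pre-order pretty printer: a node's line is printed before its children's.
class PrettyPrinter : public Visitor {
    private:
        int depth = 0;
        void say(const std::string& s) { out << std::string(depth, ' ') << s << '\n'; }
    public:
        std::ostringstream out;
        void visit(Num* n) override { say("Num " + std::to_string(n->value)); }
        void visit(Bin* b) override {
            say(std::string("Bin ") + b->op);
            ++depth;                    // children are indented one level deeper
            b->left->accept(this);
            b->right->accept(this);
            --depth;                    // back to this node's level
        }
};
```

For 1 + (2 * 3), out.str() holds "Bin +", " Num 1", " Bin *", "  Num 2", and "  Num 3", each on its own line.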

That wasn't so bad, right? Let's see what else we can do with the visitor pattern.

The Whole Point: Easily Adding Operations

Having seen how to traverse the AST with the pretty-printing example, let's use the visitor pattern for something a bit more... involved. This time, let's implement a visitor that actually interprets our AST. Once again we implement a concrete visitor class, with a method to visit each node type.

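This time, instead of pretty printing the AST, our visitor maintains an environment (mapping variable names to values) and an operand stack: each expression visit leaves its result on top of the stack, and the enclosing node pops the values it needs.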

class Interpreter : public Visitor {
    private:
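        Environment env;          // symbol table mapping identifiers to values (not shown)
        double operands[31337];   // operand stack for expression evaluation
        int n = 0;                // number of values currently on the stack
        void push(double e) {
            if (n < 31337)
                operands[n++] = e;
            else
                cout<<"Stack overflow."<<endl;
        }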
        double pop() {
            if (n > 0)
                return operands[--n];
            cout<<"Stack underflow."<<endl;
            return 0xff;
        }
        double peek(int k) {
            return operands[(n-1)-k];
        }
    public:
        void visit(StatementList* sl) override {
            for (auto m : sl->getStatements()) {
                m->accept(this);
            }
        }
        void visit(ProgramStatement* ps) override {
            ps->getStatement()->accept(this);
        }
        void visit(PrintStatement* ps) override {
            ps->getExpression()->accept(this);
            cout<<pop()<<endl;
        }
        void visit(WhileStatement* ws) override {
            ExpressionNode* testExpr = ws->getTestExpr();
            StatementList* stmt = ws->getLoopBody();
            for (;;) {
                testExpr->accept(this);
                if (pop()) {
                    stmt->accept(this);
                } else break;
            }
        }
        void visit(ExprStatement* es) override {
            if (es->getExpression() != nullptr) {
                es->getExpression()->accept(this);
                pop();  // discard the unused result so the stack doesn't grow
            }
        }
        void visit(IdExpression* idexpr) override {
            push(env[idexpr->getId()]);
        }
        void visit(LiteralExpression* lit) override {
            push(lit->eval(env));
        }
        void visit(AssignExpression* assign) override {
            string id = assign->getLeft()->getId();
            assign->getRight()->accept(this);
            env[id] = pop();
            push(env[id]);  // an assignment is an expression: leave its value on the stack
        }
        void visit(BinaryExpression* bin) override {
            bin->getLeft()->accept(this);
            bin->getRight()->accept(this);
            double rhs = pop();
            double lhs = pop();
            switch (bin->getToken().type) {
                case TK_PLUS:  push(lhs+rhs); break;
                case TK_MINUS: push(lhs-rhs); break;
                case TK_MULT:  push(lhs*rhs); break;
                case TK_DIV:   push(lhs/rhs); break;
                default:
                    break;
            }
        }
        void visit(RelOpExpression* rel) override {
            rel->getLeft()->accept(this);
            rel->getRight()->accept(this);
            double rhs = pop();
            double lhs = pop();
            switch (rel->getToken().type) {
                case TK_EQU: push(lhs == rhs); break;
                case TK_NEQ: push(lhs != rhs); break;
                case TK_LT:  push(lhs < rhs); break;
                case TK_GT:  push(lhs > rhs); break;
                case TK_LTE: push(lhs <= rhs); break;
                case TK_GTE: push(lhs >= rhs); break;
                default:
                    break;
            }
        }
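        void visit(IfStatement* is) override {
            is->getTestExpr()->accept(this);
            if (pop()) {
                is->getTrueCase()->accept(this);
            } else if (is->getFalseCase() != nullptr) {
                is->getFalseCase()->accept(this);
            }
        }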
};
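The stack discipline also fits in a stripped-down, self-contained version. It uses hypothetical Num and Bin nodes in place of the post's classes, and a std::vector in place of the fixed-size array. Each visit leaves exactly one value on the stack. A binary node pops its right operand first, because that operand was pushed last.

```cpp
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

class Num;
class Bin;

class Visitor {
    public:
        virtual ~Visitor() = default;
        virtual void visit(Num* n) = 0;
        virtual void visit(Bin* b) = 0;
};

class Node {
    public:
        virtual ~Node() = default;
        virtual void accept(Visitor* visitor) = 0;
};

class Num : public Node {
    public:
        int value;
        explicit Num(int v) : value(v) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

class Bin : public Node {
    public:
        char op;  // expected to be one of + - * /
        std::unique_ptr<Node> left, right;
        Bin(char o, std::unique_ptr<Node> l, std::unique_ptr<Node> r)
            : op(o), left(std::move(l)), right(std::move(r)) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

// Stack-based evaluator: every visit leaves exactly one value on the stack.
class Evaluator : public Visitor {
    public:
        std::vector<double> stack;
        void push(double d) { stack.push_back(d); }
        double pop() {
            if (stack.empty())
                throw std::runtime_error("operand stack underflow");
            double top = stack.back();
            stack.pop_back();
            return top;
        }
        void visit(Num* n) override { push(n->value); }
        void visit(Bin* b) override {
            b->left->accept(this);   // pushes the left operand
            b->right->accept(this);  // pushes the right operand
            double rhs = pop();      // right was pushed last, so pop it first
            double lhs = pop();
            switch (b->op) {
                case '+': push(lhs + rhs); break;
                case '-': push(lhs - rhs); break;
                case '*': push(lhs * rhs); break;
                case '/': push(lhs / rhs); break;
                default:  throw std::runtime_error("unknown operator");
            }
        }
};
```

Evaluating 1 + (2 * 3) leaves a single 7 on the stack. Evaluating 8 - 2 leaves 6, which shows why the pop order matters for non-commutative operators.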

That's the real beauty of the visitor pattern: want a compiler instead of an interpreter? Implement a code-generating visitor. Want lexical scoping? Implement a scope-resolution visitor. You get the idea. And since we don't have to stuff all of this functionality into each node class, the node classes themselves stay small and simple. Adding a new node type is the flip side: implement its accept() method, then add a matching visit() method to the Visitor interface and to every concrete visitor.
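As a small illustration of "implement a code-generating visitor", here is a self-contained sketch that compiles expressions to instructions for an imaginary stack machine. The instruction names (PUSH, ADD, SUB, MUL, DIV) and the Num, Bin, and CodeGen classes are hypothetical and do not come from this post. The node classes are untouched: the new operation is entirely a new visitor.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

class Num;
class Bin;

class Visitor {
    public:
        virtual ~Visitor() = default;
        virtual void visit(Num* n) = 0;
        virtual void visit(Bin* b) = 0;
};

class Node {
    public:
        virtual ~Node() = default;
        virtual void accept(Visitor* visitor) = 0;
};

class Num : public Node {
    public:
        int value;
        explicit Num(int v) : value(v) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

class Bin : public Node {
    public:
        char op;
        std::unique_ptr<Node> left, right;
        Bin(char o, std::unique_ptr<Node> l, std::unique_ptr<Node> r)
            : op(o), left(std::move(l)), right(std::move(r)) { }
        void accept(Visitor* visitor) override { visitor->visit(this); }
};

// Emits instructions for a hypothetical stack machine (post-order traversal).
class CodeGen : public Visitor {
    public:
        std::vector<std::string> code;
        void visit(Num* n) override { code.push_back("PUSH " + std::to_string(n->value)); }
        void visit(Bin* b) override {
            b->left->accept(this);     // operands first...
            b->right->accept(this);
            switch (b->op) {           // ...then the operator
                case '+': code.push_back("ADD"); break;
                case '-': code.push_back("SUB"); break;
                case '*': code.push_back("MUL"); break;
                case '/': code.push_back("DIV"); break;
                default:  code.push_back(std::string("UNKNOWN ") + b->op); break;
            }
        }
};
```

For 1 + (2 * 3) it emits PUSH 1, PUSH 2, PUSH 3, MUL, ADD. That is the same order in which the Interpreter above pushes and pops operands at run time.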

Well, that's all I've got for today. Until next time: Happy Hacking!

A version of the code seen here is available on my github: https://github.com/maxgoren/VisitorPattern

