Ty Scales

Writing a Mock Library in Go

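This post walks through the core components that make the mock package from Testify work. It consists of 3 parts: the Mock and Call types that store expectations, the On and Return methods that register them, and the Called method that looks up the matching expectation when the mocked method runs.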

Mock is a struct with a slice of Calls and a mutual exclusion lock (mutex) that ensures only one goroutine accesses that slice at a time.

// Mock holds the calls registered on a mock object. Embed it in a struct
// that implements an interface to turn that struct into a mock.
type Mock struct {
    // ExpectedCalls holds every Call registered through On.
    ExpectedCalls []*Call
    // mutex is locked by On while it appends to ExpectedCalls.
    mutex sync.Mutex
}

// Call describes one expected method invocation and the values to hand back
// for it.
type Call struct {
    // Parent is the Mock this call was registered on.
    Parent *Mock
    // Method is the name of the mocked method.
    Method string
    // Arguments are the arguments the method is expected to be called with.
    Arguments []interface{}
    // ReturnArguments are the values configured with Return.
    ReturnArguments []interface{}
}

Mock needs the ability to add Calls, and a Call needs the ability to set its return values. This is done by defining a method On that takes a method name and a variable number of arguments. On returns a Call, whose return values are set using Return. Finally, Mock can be embedded in any struct that implements an interface to make that struct a mock and give it access to On. The basic example looks like this:

// Adder is the example interface that the walkthrough mocks.
type Adder interface {
    // AddOne takes an int and returns an int.
    AddOne(n int) int 
}

// mockAdder is the mock implementation of Adder; embedding Mock gives it
// the On method for registering expected calls.
type mockAdder struct {
    Mock
}

// AddOne is the Adder method being mocked. Its body is intentionally left as
// a placeholder at this point in the walkthrough.
func (m *mockAdder) AddOne(n int) int {
    // TODO: get the return values from the mock
}

mockAdder := &mockAdder
mockAdder.On("AddOne", 1).Return(2)

On simply creates a new Call and appends it to the Mock's expected calls. Return, likewise, sets the ReturnArguments of the Call.

// newCall builds a Call for methodName on parent that expects
// methodArguments and starts with an empty, non-nil list of return values.
func newCall(parent *Mock, methodName string, methodArguments ...interface{}) *Call {
    call := new(Call)
    call.Parent = parent
    call.Method = methodName
    call.Arguments = methodArguments
    call.ReturnArguments = []interface{}{}
    return call
}
// On registers an expectation that methodName will be called with arguments.
// It returns the newly registered Call so that return values can be chained
// onto it with Return.
func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
    m.mutex.Lock()
    defer m.mutex.Unlock()

    c := newCall(m, methodName, arguments...)
    m.ExpectedCalls = append(m.ExpectedCalls, c)
    return c
}
// Return sets the values this Call hands back when it is matched and returns
// the Call itself so further configuration can be chained.
func (c *Call) Return(returnArguments ...interface{}) *Call {
    // A Call has no lock of its own; it is guarded by the mutex of the Mock
    // it was registered on.
    c.Parent.mutex.Lock()
    defer c.Parent.mutex.Unlock()

    c.ReturnArguments = returnArguments
    return c
}

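Now we need to actually implement the method of the interface. This is where the majority of our code will be written. To get the return values from the mock, we need to find the name of the method that was called, find the expected call that matches that name and the arguments passed, and return that call's ReturnArguments.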

First, we define a Called method that takes in the arguments and gets the name of the function that called it.

// Called returns the return values registered for the method that invoked
// it. The method name is read from the caller's stack frame, so Called must
// be invoked directly from the mocked method. It panics if the caller cannot
// be identified or if no registered call matches.
func (m *Mock) Called(arguments ...interface{}) []interface{} {
    pc, _, _, ok := runtime.Caller(1)
    if !ok {
        panic("Couldn't get the caller information")
    }
    // The qualified name looks like "pkg/path.(*Type).Method"; the method
    // name is whatever follows the final dot.
    functionPath := runtime.FuncForPC(pc).Name()
    functionName := functionPath[strings.LastIndex(functionPath, ".")+1:]

    return m.MethodCalled(functionName, arguments...)
}

// MethodCalled returns the ReturnArguments of the registered call that
// matches methodName and arguments. It panics when no registered call
// matches, so an unexpected invocation is reported at the point it happens.
func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) []interface{} {
    m.mutex.Lock()
    defer m.mutex.Unlock()

    call := m.findExpectedCall(methodName, arguments...)
    if call == nil {
        panic("mock: no expected call matches method " + methodName + " with the given arguments")
    }
    return call.ReturnArguments
}

Now we’ll need a function to find the expected call and, for each call that matches the method name, a function to compare the arguments.

// findExpectedCall returns the registered call for method whose expected
// arguments all match arguments, or nil when there is none. When several
// registered calls match, the most recently registered one is returned.
func (m *Mock) findExpectedCall(method string, arguments ...interface{}) *Call {
    // Walk newest-to-oldest so the first hit is the last call registered.
    for i := len(m.ExpectedCalls) - 1; i >= 0; i-- {
        call := m.ExpectedCalls[i]
        if call.Method != method {
            continue
        }
        if argumentDiff(call.Arguments, arguments) == 0 {
            return call
        }
    }
    return nil
}

argumentDiff iterates through each argument of the call, counting any differences. With the exception of some special cases, the comparison is done using reflect.DeepEqual.

func argumentDiff(callArgs ...interface{}, arguments ...interface{}) int {
    var differences int
    maxArgCount := len(callArgs)
    if len(arguments) > maxArgCount {
        maxArgCount = len(arguments)
    }
    for i := 0; i < maxArgCount; i++ {
        var actual, expected interface{}
        if len(arguments) <= i {
            actual = "(Missing)"
        } else {
            actual = arguments[i]
        }

        if len(callArgs) <= i {
            expected = "(Missing)"
        } else {
            expected = callArgs[i]
        }
        if !objectsAreEqual(actual, expected) {
            differences++
        }     
    }
    return differences
}  
// objectsAreEqual reports whether expected and actual are equal.
//
// A []byte expected value is compared by content with bytes.Equal, except
// that a nil slice equals only another nil slice; if actual is not a []byte
// the result is false. Every other type is compared with reflect.DeepEqual.
func objectsAreEqual(expected, actual interface{}) bool {
	switch exp := expected.(type) {
	case []byte:
		act, ok := actual.([]byte)
		if !ok {
			return false
		}
		// bytes.Equal treats nil and empty slices as equal, so nil is
		// handled separately to keep the two distinct.
		if exp == nil || act == nil {
			return exp == nil && act == nil
		}
		return bytes.Equal(exp, act)
	default:
		return reflect.DeepEqual(expected, actual)
	}
}

We can now complete the AddOne method of our mock by passing its arguments to Called and returning the value Called gives back.

// Adder is the example interface that the walkthrough mocks.
type Adder interface {
	// AddOne takes an int and returns an int.
	AddOne(n int) int
}

// mockAdder is the mock implementation of Adder; embedding Mock gives it
// the On and Called methods.
type mockAdder struct {
	Mock
}

// AddOne forwards n to Called and returns the first value Called hands back,
// asserted to be an int. It panics when Called yields no values.
func (m *mockAdder) AddOne(n int) int {
	returns := m.Called(n)
	if len(returns) == 0 {
		panic("error: expected 1 return argument for mocked function AddOne")
	}
	result := returns[0].(int)
	return result
}

// TestMock registers AddOne(1) -> 2 on the mock and checks that calling
// AddOne(1) yields the registered value.
func TestMock(t *testing.T) {
	adder := &mockAdder{}
	adder.On("AddOne", 1).Return(2)

	if res := adder.AddOne(1); res != 2 {
		t.Fatalf("expected 2, got %d", res)
	}
}

The test passes! From here, you can extend the argumentDiff function to handle more special cases. For example, if you want a placeholder that always matches, you could export a constant Anything:

// Anything is a placeholder for an expected argument: a position whose
// expected value equals Anything is not counted as a difference, whatever
// value was actually passed.
const Anything = "mock.Anything"

// ArgumentDiff is an excerpt showing the one change needed to support
// Anything; each "..." stands for lines omitted from this listing.
// NOTE(review): the earlier listing and findExpectedCall spell this function
// argumentDiff, and this excerpt ends without returning differences — confirm
// both against the full code.
func ArgumentDiff(callArgs ...interface{}, arguments ...interface{}) int {
    ...
    for i := 0; i < maxArgCount; i++ {
        ...
        // A mismatch is not counted when the expected value is Anything.
        if !objectsAreEqual(actual, expected) && !objectsAreEqual(expected, Anything) {
            differences++
        }     
    }

}

Now a mock can accept any value in that position and still match. This is especially useful for functions that take a context, which rarely needs an exact match.

// Adder is the example interface, here with a context.Context as the first
// parameter of AddOne.
type Adder interface {
    // AddOne takes a context and an int and returns an int.
    AddOne(ctx context.Context, n int) int 
}

// mockAdder is the mock for Adder; it embeds Mock to gain the On method.
type mockAdder struct {
    Mock
}

// TestMock registers AddOne with Anything in the context position, so the
// expectation matches whichever context the caller passes, and checks that
// the registered return value comes back.
func TestMock(t *testing.T) {
	m := &mockAdder{}
	// Anything is referenced unqualified: this code lives in the same
	// package as Mock, which mockAdder also embeds unqualified.
	m.On("AddOne", Anything, 1).Return(2)

	res := m.AddOne(context.Background(), 1)
	if res != 2 {
		t.Fatalf("expected 2, got %d", res)
	}
}

The full code is available on Github