diff --git a/util/error.go b/util/error.go index 93f212e..ebed6a8 100644 --- a/util/error.go +++ b/util/error.go @@ -94,3 +94,13 @@ func (err *Err) Error() string { } return msg } + +// HasErrCode returns true if err is an *Err and err.Code == code. +func HasErrCode(err error, code ErrCode) bool { + switch err := err.(type) { + case *Err: + return err.Code == code + default: + return false + } +} diff --git a/wire/scanner_test.go b/wire/scanner_test.go index 20b402b..6226404 100644 --- a/wire/scanner_test.go +++ b/wire/scanner_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/util" ) func TestReadStatusOkay(t *testing.T) { @@ -20,7 +21,7 @@ func TestReadStatusOkay(t *testing.T) { func TestReadIncompleteStatus(t *testing.T) { s := NewScannerString("oka") _, err := s.ReadStatus() - assert.Equal(t, incompleteMessage("status", 3, 4), err) + assert.Equal(t, errIncompleteMessage("status", 3, 4), err) assertEof(t, s) } @@ -35,7 +36,7 @@ func TestReadLength(t *testing.T) { func TestReadIncompleteLength(t *testing.T) { s := NewScannerString("aaa") _, err := s.readLength() - assert.Equal(t, incompleteMessage("length", 3, 4), err) + assert.Equal(t, errIncompleteMessage("length", 3, 4), err) assertEof(t, s) } @@ -78,7 +79,7 @@ func TestReadIncompleteMessage(t *testing.T) { s := NewScannerString("0005hel") msg, err := ReadMessageString(s) assert.Error(t, err) - assert.Equal(t, incompleteMessage("message data", 3, 5), err) + assert.Equal(t, errIncompleteMessage("message data", 3, 5), err) assert.Equal(t, "hel\000\000", msg) assertEof(t, s) } @@ -97,7 +98,7 @@ func NewEofBuffer(str string) *TestReader { func assertEof(t *testing.T, s *realScanner) { msg, err := s.ReadMessage() - assert.Equal(t, io.EOF, err) + assert.True(t, util.HasErrCode(err, util.NetworkError)) assert.Nil(t, msg) } diff --git a/wire/sync_test.go b/wire/sync_test.go index d11a522..1966d55 100644 --- a/wire/sync_test.go +++ b/wire/sync_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/util" ) var ( @@ -35,7 +36,7 @@ func TestSyncSendOctetStringTooLong(t *testing.T) { var buf bytes.Buffer s := NewSyncSender(&buf) err := s.SendOctetString("hello") - assert.EqualError(t, err, "octet string must be exactly 4 bytes: 'hello'") + assert.Equal(t, util.AssertionErrorf("octet string must be exactly 4 bytes: 'hello'"), err) } func TestSyncReadTime(t *testing.T) { @@ -63,7 +64,7 @@ func TestSyncReadString(t *testing.T) { func TestSyncReadStringTooShort(t *testing.T) { s := NewSyncScanner(strings.NewReader("\005\000\000\000h")) _, err := s.ReadString() - assert.EqualError(t, err, "incomplete bytes: read 1 bytes, expecting 5") + assert.Equal(t, errIncompleteMessage("bytes", 1, 5), err) } func TestSyncSendString(t *testing.T) {