Richer error handling.

This commit is contained in:
Zach Klippenstein 2015-07-11 23:18:58 -07:00
parent 0817d29438
commit 34b3a07ca8
19 changed files with 389 additions and 154 deletions

View file

@ -5,6 +5,7 @@ import (
"io"
"strings"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
@ -28,19 +29,23 @@ func (c *DeviceClient) String() string {
// get-product is documented, but not implemented in the server.
// TODO(z): Make getProduct exported if get-product is ever implemented in adb.
func (c *DeviceClient) getProduct() (string, error) {
return c.getAttribute("get-product")
attr, err := c.getAttribute("get-product")
return attr, wrapClientError(err, c, "GetProduct")
}
func (c *DeviceClient) GetSerial() (string, error) {
return c.getAttribute("get-serialno")
attr, err := c.getAttribute("get-serialno")
return attr, wrapClientError(err, c, "GetSerial")
}
func (c *DeviceClient) GetDevicePath() (string, error) {
return c.getAttribute("get-devpath")
attr, err := c.getAttribute("get-devpath")
return attr, wrapClientError(err, c, "GetDevicePath")
}
func (c *DeviceClient) GetState() (string, error) {
return c.getAttribute("get-state")
attr, err := c.getAttribute("get-state")
return attr, wrapClientError(err, c, "GetState")
}
/*
@ -62,12 +67,12 @@ contain double quotes.
func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) {
cmd, err := prepareCommandLine(cmd, args...)
if err != nil {
return "", err
return "", wrapClientError(err, c, "RunCommand")
}
conn, err := c.dialDevice()
if err != nil {
return "", err
return "", wrapClientError(err, c, "RunCommand")
}
defer conn.Close()
@ -77,18 +82,14 @@ func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) {
// We read until the stream is closed.
// So, we can't use conn.RoundTripSingleResponse.
if err = conn.SendMessage([]byte(req)); err != nil {
return "", err
return "", wrapClientError(err, c, "RunCommand")
}
if err = wire.ReadStatusFailureAsError(conn, req); err != nil {
return "", err
return "", wrapClientError(err, c, "RunCommand")
}
resp, err := conn.ReadUntilEof()
if err != nil {
return "", err
}
return string(resp), nil
return string(resp), wrapClientError(err, c, "RunCommand")
}
/*
@ -103,39 +104,42 @@ Source: https://android.googlesource.com/platform/system/core/+/master/adb/SERVI
func (c *DeviceClient) Remount() (string, error) {
conn, err := c.dialDevice()
if err != nil {
return "", err
return "", wrapClientError(err, c, "Remount")
}
defer conn.Close()
resp, err := conn.RoundTripSingleResponse([]byte("remount"))
return string(resp), err
return string(resp), wrapClientError(err, c, "Remount")
}
func (c *DeviceClient) ListDirEntries(path string) (*DirEntries, error) {
conn, err := c.getSyncConn()
if err != nil {
return nil, err
return nil, wrapClientError(err, c, "ListDirEntries(%s)", path)
}
return listDirEntries(conn, path)
entries, err := listDirEntries(conn, path)
return entries, wrapClientError(err, c, "ListDirEntries(%s)", path)
}
func (c *DeviceClient) Stat(path string) (*DirEntry, error) {
conn, err := c.getSyncConn()
if err != nil {
return nil, err
return nil, wrapClientError(err, c, "Stat(%s)", path)
}
return stat(conn, path)
entry, err := stat(conn, path)
return entry, wrapClientError(err, c, "Stat(%s)", path)
}
func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) {
conn, err := c.getSyncConn()
if err != nil {
return nil, err
return nil, wrapClientError(err, c, "OpenRead(%s)", path)
}
return receiveFile(conn, path)
reader, err := receiveFile(conn, path)
return reader, wrapClientError(err, c, "OpenRead(%s)", path)
}
// getAttribute returns the first message returned by the server by running
@ -149,18 +153,35 @@ func (c *DeviceClient) getAttribute(attr string) (string, error) {
return string(resp), nil
}
func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
conn, err := c.dialDevice()
if err != nil {
return nil, err
}
// Switch the connection to sync mode.
if err := wire.SendMessageString(conn, "sync:"); err != nil {
return nil, err
}
if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil {
return nil, err
}
return conn.NewSyncConn(), nil
}
// dialDevice switches the connection to communicate directly with the device
// by requesting the transport defined by the DeviceDescriptor.
func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
conn, err := c.config.Dialer.Dial()
if err != nil {
return nil, fmt.Errorf("error dialing adb server (%s): %+v", c.config.Dialer, err)
return nil, err
}
req := fmt.Sprintf("host:%s", c.descriptor.getTransportDescriptor())
if err = wire.SendMessageString(conn, req); err != nil {
conn.Close()
return nil, fmt.Errorf("error connecting to device '%s': %+v", c.descriptor, err)
return nil, util.WrapErrf(err, "error connecting to device '%s'", c.descriptor)
}
if err = wire.ReadStatusFailureAsError(conn, req); err != nil {
@ -171,33 +192,16 @@ func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
return conn, nil
}
func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
conn, err := c.dialDevice()
if err != nil {
return nil, fmt.Errorf("error connecting to device for sync: %+v", err)
}
// Switch the connection to sync mode.
if err := wire.SendMessageString(conn, "sync:"); err != nil {
return nil, fmt.Errorf("error requesting sync mode: %+v", err)
}
if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil {
return nil, err
}
return conn.NewSyncConn(), nil
}
// prepareCommandLine validates the command and argument strings, quotes
// arguments if required, and joins them into a valid adb command string.
func prepareCommandLine(cmd string, args ...string) (string, error) {
if isBlank(cmd) {
return "", fmt.Errorf("command cannot be empty")
return "", util.AssertionErrorf("command cannot be empty")
}
for i, arg := range args {
if strings.ContainsRune(arg, '"') {
return "", fmt.Errorf("arg at index %d contains an invalid double quote: %s", i, arg)
return "", util.Errorf(util.ParseError, "arg at index %d contains an invalid double quote: %s", i, arg)
}
if containsWhitespace(arg) {
args[i] = fmt.Sprintf("\"%s\"", arg)

View file

@ -4,6 +4,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
@ -52,12 +53,14 @@ func TestPrepareCommandLineNoArgs(t *testing.T) {
func TestPrepareCommandLineEmptyCommand(t *testing.T) {
_, err := prepareCommandLine("")
assert.EqualError(t, err, "command cannot be empty")
assert.Equal(t, util.AssertionError, code(err))
assert.Equal(t, "command cannot be empty", message(err))
}
func TestPrepareCommandLineBlankCommand(t *testing.T) {
_, err := prepareCommandLine(" ")
assert.EqualError(t, err, "command cannot be empty")
assert.Equal(t, util.AssertionError, code(err))
assert.Equal(t, "command cannot be empty", message(err))
}
func TestPrepareCommandLineCleanArgs(t *testing.T) {
@ -74,5 +77,14 @@ func TestPrepareCommandLineArgWithWhitespaceQuotes(t *testing.T) {
func TestPrepareCommandLineArgWithDoubleQuoteFails(t *testing.T) {
_, err := prepareCommandLine("cmd", "quoted\"arg")
assert.EqualError(t, err, "arg at index 0 contains an invalid double quote: quoted\"arg")
assert.Equal(t, util.ParseError, code(err))
assert.Equal(t, "arg at index 0 contains an invalid double quote: quoted\"arg", message(err))
}
func code(err error) util.ErrCode {
return err.(*util.Err).Code
}
func message(err error) string {
return err.(*util.Err).Message
}

View file

@ -2,8 +2,9 @@ package goadb
import (
"bufio"
"fmt"
"strings"
"github.com/zach-klippenstein/goadb/util"
)
type DeviceInfo struct {
@ -26,7 +27,7 @@ func (d *DeviceInfo) IsUsb() bool {
func newDevice(serial string, attrs map[string]string) (*DeviceInfo, error) {
if serial == "" {
return nil, fmt.Errorf("device serial cannot be blank")
return nil, util.AssertionErrorf("device serial cannot be blank")
}
return &DeviceInfo{
@ -49,9 +50,6 @@ func parseDeviceList(list string, lineParseFunc func(string) (*DeviceInfo, error
}
devices = append(devices, device)
}
if err := scanner.Err(); err != nil {
return nil, err
}
return devices, nil
}
@ -59,7 +57,8 @@ func parseDeviceList(list string, lineParseFunc func(string) (*DeviceInfo, error
func parseDeviceShort(line string) (*DeviceInfo, error) {
fields := strings.Fields(line)
if len(fields) != 2 {
return nil, fmt.Errorf("malformed device line, expected 2 fields but found %d", len(fields))
return nil, util.Errorf(util.ParseError,
"malformed device line, expected 2 fields but found %d", len(fields))
}
return newDevice(fields[0], map[string]string{})
@ -68,7 +67,8 @@ func parseDeviceShort(line string) (*DeviceInfo, error) {
func parseDeviceLong(line string) (*DeviceInfo, error) {
fields := strings.Fields(line)
if len(fields) < 5 {
return nil, fmt.Errorf("malformed device line, expected at least 5 fields but found %d", len(fields))
return nil, util.Errorf(util.ParseError,
"malformed device line, expected at least 5 fields but found %d", len(fields))
}
attrs := parseDeviceAttributes(fields[2:])

View file

@ -5,6 +5,7 @@ import (
"net"
"runtime"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
@ -50,9 +51,10 @@ func (d *netDialer) Dial() (*wire.Conn, error) {
host := d.Host
port := d.Port
netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port))
address := fmt.Sprintf("%s:%d", host, port)
netConn, err := net.Dial("tcp", address)
if err != nil {
return nil, err
return nil, util.WrapErrorf(err, util.NetworkError, "error dialing %s", address)
}
conn := &wire.Conn{
@ -68,7 +70,6 @@ func (d *netDialer) Dial() (*wire.Conn, error) {
return conn, nil
}
// TODO(zach): Make this unexported.
func roundTripSingleResponse(d Dialer, req string) ([]byte, error) {
conn, err := d.Dial()
if err != nil {

View file

@ -4,6 +4,7 @@ package goadb
import (
"strconv"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
@ -46,11 +47,14 @@ func NewHostClient(config ClientConfig) *HostClient {
func (c *HostClient) GetServerVersion() (int, error) {
resp, err := roundTripSingleResponse(c.config.Dialer, "host:version")
if err != nil {
return 0, err
return 0, wrapClientError(err, c, "GetServerVersion")
}
version, err := strconv.ParseInt(string(resp), 16, 32)
return int(version), err
version, err := c.parseServerVersion(resp)
if err != nil {
return 0, wrapClientError(err, c, "GetServerVersion")
}
return version, nil
}
/*
@ -62,12 +66,12 @@ Corresponds to the command:
func (c *HostClient) KillServer() error {
conn, err := c.config.Dialer.Dial()
if err != nil {
return err
return wrapClientError(err, c, "KillServer")
}
defer conn.Close()
if err = wire.SendMessageString(conn, "host:kill"); err != nil {
return err
return wrapClientError(err, c, "KillServer")
}
return nil
@ -82,12 +86,12 @@ Corresponds to the command:
func (c *HostClient) ListDeviceSerials() ([]string, error) {
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices")
if err != nil {
return nil, err
return nil, wrapClientError(err, c, "ListDeviceSerials")
}
devices, err := parseDeviceList(string(resp), parseDeviceShort)
if err != nil {
return nil, err
return nil, wrapClientError(err, c, "ListDeviceSerials")
}
serials := make([]string, len(devices))
@ -106,8 +110,22 @@ Corresponds to the command:
func (c *HostClient) ListDevices() ([]*DeviceInfo, error) {
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices-l")
if err != nil {
return nil, err
return nil, wrapClientError(err, c, "ListDevices")
}
return parseDeviceList(string(resp), parseDeviceLong)
devices, err := parseDeviceList(string(resp), parseDeviceLong)
if err != nil {
return nil, wrapClientError(err, c, "ListDevices")
}
return devices, nil
}
func (c *HostClient) parseServerVersion(versionRaw []byte) (int, error) {
versionStr := string(versionRaw)
version, err := strconv.ParseInt(versionStr, 16, 32)
if err != nil {
return 0, util.WrapErrorf(err, util.ParseError,
"error parsing server version: %s", versionStr)
}
return int(version), nil
}

View file

@ -6,6 +6,7 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
@ -46,7 +47,7 @@ func (s *MockServer) ReadStatus() (wire.StatusCode, error) {
func (s *MockServer) ReadMessage() ([]byte, error) {
if s.nextMsgIndex >= len(s.Messages) {
return nil, io.EOF
return nil, util.WrapErrorf(io.EOF, util.NetworkError, "")
}
s.nextMsgIndex++

View file

@ -1,6 +1,10 @@
package goadb
import "os/exec"
import (
"os/exec"
"github.com/zach-klippenstein/goadb/util"
)
/*
StartServer ensures there is a server running.
@ -10,5 +14,6 @@ Currently implemented by just running
*/
func StartServer() error {
cmd := exec.Command("adb", "start-server")
return cmd.Run()
err := cmd.Run()
return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s", err)
}

View file

@ -2,11 +2,11 @@
package goadb
import (
"fmt"
"io"
"os"
"time"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
@ -25,7 +25,7 @@ func stat(conn *wire.SyncConn, path string) (*DirEntry, error) {
return nil, err
}
if id != "STAT" {
return nil, fmt.Errorf("expected stat ID 'STAT', but got '%s'", id)
return nil, util.Errorf(util.AssertionError, "expected stat ID 'STAT', but got '%s'", id)
}
return readStat(conn)
@ -56,24 +56,24 @@ func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) {
func readStat(s wire.SyncScanner) (entry *DirEntry, err error) {
mode, err := s.ReadFileMode()
if err != nil {
err = fmt.Errorf("error reading file mode: %v", err)
err = util.WrapErrf(err, "error reading file mode: %v", err)
return
}
size, err := s.ReadInt32()
if err != nil {
err = fmt.Errorf("error reading file size: %v", err)
err = util.WrapErrf(err, "error reading file size: %v", err)
return
}
mtime, err := s.ReadTime()
if err != nil {
err = fmt.Errorf("error reading file time: %v", err)
err = util.WrapErrf(err, "error reading file time: %v", err)
return
}
// adb doesn't indicate when a file doesn't exist, but will return all zeros.
// Theoretically this could be an actual file, but that's very unlikely.
if mode == os.FileMode(0) && size == 0 && mtime == zeroTime {
return nil, os.ErrNotExist
return nil, util.Errorf(util.FileNoExistError, "file doesn't exist")
}
entry = &DirEntry{

View file

@ -8,6 +8,8 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zach-klippenstein/goadb/util"
"github.com/zach-klippenstein/goadb/wire"
)
@ -26,6 +28,7 @@ func TestStatValid(t *testing.T) {
entry, err := stat(conn, "/thing")
assert.NoError(t, err)
require.NotNil(t, entry)
assert.Equal(t, mode, entry.Mode, "expected os.FileMode %s, got %s", mode, entry.Mode)
assert.Equal(t, int32(4), entry.Size)
assert.Equal(t, someTime, entry.ModifiedAt)
@ -54,5 +57,5 @@ func TestStatNoExist(t *testing.T) {
entry, err := stat(conn, "/")
assert.Nil(t, entry)
assert.Equal(t, os.ErrNotExist, err)
assert.Equal(t, util.FileNoExistError, err.(*util.Err).Code)
}

24
util.go
View file

@ -1,8 +1,12 @@
package goadb
import "strings"
import (
"fmt"
"reflect"
"regexp"
"strings"
"github.com/zach-klippenstein/goadb/util"
)
var (
@ -16,3 +20,21 @@ func containsWhitespace(str string) bool {
func isBlank(str string) bool {
return whitespaceRegex.MatchString(str)
}
func wrapClientError(err error, client interface{}, operation string, args ...interface{}) error {
if err == nil {
return nil
}
if _, ok := err.(*util.Err); !ok {
panic("err is not a *util.Err")
}
clientType := reflect.TypeOf(client)
return &util.Err{
Code: err.(*util.Err).Code,
Cause: err,
Message: fmt.Sprintf("error performing %s on %s", fmt.Sprintf(operation, args...), clientType),
Details: client,
}
}

16
util/errcode_string.go Normal file
View file

@ -0,0 +1,16 @@
// generated by stringer -type=ErrCode; DO NOT EDIT
package util
import "fmt"
const _ErrCode_name = "AssertionErrorParseErrorServerNotAvailableNetworkErrorAdbErrorDeviceNotFoundFileNoExistError"
var _ErrCode_index = [...]uint8{0, 14, 24, 42, 54, 62, 76, 92}
func (i ErrCode) String() string {
if i+1 >= ErrCode(len(_ErrCode_index)) {
return fmt.Sprintf("ErrCode(%d)", i)
}
return _ErrCode_name[_ErrCode_index[i]:_ErrCode_index[i+1]]
}

96
util/error.go Normal file
View file

@ -0,0 +1,96 @@
package util
import "fmt"
// Err is the implementation of error that all goadb functions return.
type Err struct {
// Code is the high-level "type" of error.
Code ErrCode
// Message is a human-readable description of the error.
Message string
// Details is optional, and can be used to associate any auxiliary data with an error.
Details interface{}
// Cause is optional, and points to the more specific error that caused this one.
Cause error
}
var _ error = &Err{}
//go:generate stringer -type=ErrCode
type ErrCode byte
const (
AssertionError ErrCode = iota
ParseError ErrCode = iota
// The server was not available on the request port and could not be started.
ServerNotAvailable ErrCode = iota
// General network error communicating with the server.
NetworkError ErrCode = iota
// The server returned an error message, but we couldn't parse it.
AdbError ErrCode = iota
// The server returned a "device not found" error.
DeviceNotFound ErrCode = iota
// Tried to perform an operation on a path that doesn't exist on the device.
FileNoExistError ErrCode = iota
)
func Errorf(code ErrCode, format string, args ...interface{}) error {
return &Err{
Code: code,
Message: fmt.Sprintf(format, args...),
}
}
/*
WrapErrf returns an *Err that wraps another *Err and has the same ErrCode.
Panics if cause is not an *Err.
To wrap generic errors, use WrapErrorf.
*/
func WrapErrf(cause error, format string, args ...interface{}) error {
if cause == nil {
return nil
}
err := cause.(*Err)
return &Err{
Code: err.Code,
Message: fmt.Sprintf(format, args...),
Cause: err,
}
}
/*
WrapErrorf returns an *Err that wraps another arbitrary error with an ErrCode and a message.
If cause is nil, returns nil, so you can use it like
return util.WrapErrorf(DoSomethingDangerous(), util.NetworkError, "well that didn't work")
If cause is known to be of type *Err, use WrapErrf.
*/
func WrapErrorf(cause error, code ErrCode, format string, args ...interface{}) error {
if cause == nil {
return nil
}
return &Err{
Code: code,
Message: fmt.Sprintf(format, args...),
Cause: cause,
}
}
func AssertionErrorf(format string, args ...interface{}) error {
return &Err{
Code: AssertionError,
Message: fmt.Sprintf(format, args...),
}
}
func (err *Err) Error() string {
msg := fmt.Sprintf("%s: %s", err.Code, err.Message)
if err.Details != nil {
msg = fmt.Sprintf("%s (%+v)", msg, err.Details)
}
return msg
}

View file

@ -1,24 +0,0 @@
package wire
import (
"fmt"
)
type AdbServerError struct {
Request string
ServerMsg string
}
var _ error = &AdbServerError{}
func (e *AdbServerError) Error() string {
if e.Request == "" {
return fmt.Sprintf("server error: %s", e.ServerMsg)
} else {
return fmt.Sprintf("server error for %s request: %s", e.Request, e.ServerMsg)
}
}
func incompleteMessage(description string, actual int, expected int) error {
return fmt.Errorf("incomplete %s: read %d bytes, expecting %d", description, actual, expected)
}

View file

@ -1,5 +1,7 @@
package wire
import "github.com/zach-klippenstein/goadb/util"
const (
// The official implementation of adb imposes an undocumented 255-byte limit
// on messages.
@ -57,8 +59,20 @@ func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) {
}
func (conn *Conn) Close() error {
if err := conn.Sender.Close(); err != nil {
return err
errs := struct {
SenderErr error
ScannerErr error
}{
SenderErr: conn.Sender.Close(),
ScannerErr: conn.Scanner.Close(),
}
return conn.Scanner.Close()
if errs.ScannerErr != nil || errs.SenderErr != nil {
return &util.Err{
Code: util.NetworkError,
Message: "error closing connection",
Details: errs,
}
}
return nil
}

View file

@ -1,10 +1,11 @@
package wire
import (
"fmt"
"io"
"io/ioutil"
"strconv"
"github.com/zach-klippenstein/goadb/util"
)
// StatusCodes are returned by the server. If the code indicates failure, the
@ -54,33 +55,41 @@ func ReadMessageString(s Scanner) (string, error) {
func (s *realScanner) ReadStatus() (StatusCode, error) {
status := make([]byte, 4)
n, err := io.ReadFull(s.reader, status)
if err != nil && err != io.ErrUnexpectedEOF {
return "", err
return "", util.WrapErrorf(err, util.NetworkError, "error reading status")
} else if err == io.ErrUnexpectedEOF {
return StatusCode(status), incompleteMessage("status", n, 4)
return StatusCode(status), errIncompleteMessage("status", n, 4)
}
return StatusCode(status), nil
}
func (s *realScanner) ReadMessage() ([]byte, error) {
var err error
length, err := s.readLength()
if err != nil {
return nil, err
return nil, util.WrapErrorf(err, util.NetworkError, "error reading message length")
}
data := make([]byte, length)
n, err := io.ReadFull(s.reader, data)
if err != nil && err != io.ErrUnexpectedEOF {
return data, fmt.Errorf("error reading message data: %v", err)
return data, util.WrapErrorf(err, util.NetworkError, "error reading message data")
} else if err == io.ErrUnexpectedEOF {
return data, incompleteMessage("message data", n, length)
return data, errIncompleteMessage("message data", n, length)
}
return data, nil
}
func (s *realScanner) ReadUntilEof() ([]byte, error) {
return ioutil.ReadAll(s.reader)
data, err := ioutil.ReadAll(s.reader)
if err != nil {
return nil, util.WrapErrorf(err, util.NetworkError, "error reading until EOF")
}
return data, nil
}
func (s *realScanner) NewSyncScanner() SyncScanner {
@ -88,21 +97,21 @@ func (s *realScanner) NewSyncScanner() SyncScanner {
}
func (s *realScanner) Close() error {
return s.reader.Close()
return util.WrapErrorf(s.reader.Close(), util.NetworkError, "error closing scanner")
}
func (s *realScanner) readLength() (int, error) {
lengthHex := make([]byte, 4)
n, err := io.ReadFull(s.reader, lengthHex)
if err != nil && err != io.ErrUnexpectedEOF {
return 0, err
return 0, util.WrapErrorf(err, util.NetworkError, "error reading length")
} else if err == io.ErrUnexpectedEOF {
return 0, incompleteMessage("length", n, 4)
return 0, errIncompleteMessage("length", n, 4)
}
length, err := strconv.ParseInt(string(lengthHex), 16, 64)
if err != nil {
return 0, fmt.Errorf("invalid hex length: %v", err)
return 0, util.WrapErrorf(err, util.NetworkError, "could not parse hex length %v", lengthHex)
}
// Clip the length to 255, as per the Google implementation.

View file

@ -3,6 +3,8 @@ package wire
import (
"fmt"
"io"
"github.com/zach-klippenstein/goadb/util"
)
// Sender sends messages to the server.
@ -28,7 +30,7 @@ func SendMessageString(s Sender, msg string) error {
func (s *realSender) SendMessage(msg []byte) error {
if len(msg) > MaxMessageLength {
return fmt.Errorf("message length exceeds maximum: %d", len(msg))
return util.AssertionErrorf("message length exceeds maximum: %d", len(msg))
}
lengthAndMsg := fmt.Sprintf("%04x%s", len(msg), msg)
@ -40,7 +42,7 @@ func (s *realSender) NewSyncSender() SyncSender {
}
func (s *realSender) Close() error {
return s.writer.Close()
return util.WrapErrorf(s.writer.Close(), util.NetworkError, "error closing sender")
}
var _ Sender = &realSender{}

View file

@ -2,10 +2,11 @@ package wire
import (
"encoding/binary"
"fmt"
"io"
"os"
"time"
"github.com/zach-klippenstein/goadb/util"
)
type SyncScanner interface {
@ -38,10 +39,10 @@ func NewSyncScanner(r io.Reader) SyncScanner {
func RequireOctetString(s SyncScanner, expected string) error {
actual, err := s.ReadOctetString()
if err != nil {
return fmt.Errorf("expected to read '%s', got err: %v", expected, err)
return util.WrapErrorf(err, util.NetworkError, "expected to read '%s'", expected)
}
if actual != expected {
return fmt.Errorf("expected to read '%s', got '%s'", expected, actual)
return util.AssertionErrorf("expected to read '%s', got '%s'", expected, actual)
}
return nil
}
@ -49,10 +50,11 @@ func RequireOctetString(s SyncScanner, expected string) error {
func (s *realSyncScanner) ReadOctetString() (string, error) {
octet := make([]byte, 4)
n, err := io.ReadFull(s.Reader, octet)
if err != nil && err != io.ErrUnexpectedEOF {
return "", err
return "", util.WrapErrorf(err, util.NetworkError, "error reading octet string from sync scanner")
} else if err == io.ErrUnexpectedEOF {
return "", incompleteMessage("octet", n, 4)
return "", errIncompleteMessage("octet", n, 4)
}
return string(octet), nil
@ -60,20 +62,21 @@ func (s *realSyncScanner) ReadOctetString() (string, error) {
func (s *realSyncScanner) ReadInt32() (int32, error) {
var value int32
err := binary.Read(s.Reader, binary.LittleEndian, &value)
return value, err
return value, util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner")
}
func (s *realSyncScanner) ReadFileMode() (filemode os.FileMode, err error) {
func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) {
var value uint32
err = binary.Read(s.Reader, binary.LittleEndian, &value)
if err == nil {
filemode = ParseFileModeFromAdb(value)
err := binary.Read(s.Reader, binary.LittleEndian, &value)
if err != nil {
return 0, util.WrapErrorf(err, util.NetworkError, "error reading filemode from sync scanner")
}
return
return ParseFileModeFromAdb(value), nil
}
func (s *realSyncScanner) ReadTime() (time.Time, error) {
seconds, err := s.ReadInt32()
if err != nil {
return time.Time{}, err
return time.Time{}, util.WrapErrorf(err, util.NetworkError, "error reading time from sync scanner")
}
return time.Unix(int64(seconds), 0).UTC(), nil
@ -82,15 +85,15 @@ func (s *realSyncScanner) ReadTime() (time.Time, error) {
func (s *realSyncScanner) ReadString() (string, error) {
length, err := s.ReadInt32()
if err != nil {
return "", err
return "", util.WrapErrorf(err, util.NetworkError, "error reading length from sync scanner")
}
bytes := make([]byte, length)
n, err := io.ReadFull(s.Reader, bytes)
if err != nil && err != io.ErrUnexpectedEOF {
return "", err
} else if err == io.ErrUnexpectedEOF {
return "", incompleteMessage("bytes", n, int(length))
n, rawErr := io.ReadFull(s.Reader, bytes)
if rawErr != nil && rawErr != io.ErrUnexpectedEOF {
return "", util.WrapErrorf(rawErr, util.NetworkError, "error reading string from sync scanner")
} else if rawErr == io.ErrUnexpectedEOF {
return "", errIncompleteMessage("bytes", n, int(length))
}
return string(bytes), nil
@ -98,7 +101,7 @@ func (s *realSyncScanner) ReadString() (string, error) {
func (s *realSyncScanner) ReadBytes() (io.Reader, error) {
length, err := s.ReadInt32()
if err != nil {
return nil, err
return nil, util.WrapErrorf(err, util.NetworkError, "error reading bytes from sync scanner")
}
return io.LimitReader(s.Reader, int64(length)), nil
@ -106,7 +109,7 @@ func (s *realSyncScanner) ReadBytes() (io.Reader, error) {
func (s *realSyncScanner) Close() error {
if closer, ok := s.Reader.(io.Closer); ok {
return closer.Close()
return util.WrapErrorf(closer.Close(), util.NetworkError, "error closing sync scanner")
}
return nil
}

View file

@ -2,10 +2,11 @@ package wire
import (
"encoding/binary"
"fmt"
"io"
"os"
"time"
"github.com/zach-klippenstein/goadb/util"
)
type SyncSender interface {
@ -29,21 +30,28 @@ func NewSyncSender(w io.Writer) SyncSender {
func (s *realSyncSender) SendOctetString(str string) error {
if len(str) != 4 {
return fmt.Errorf("octet string must be exactly 4 bytes: '%s'", str)
return util.AssertionErrorf("octet string must be exactly 4 bytes: '%s'", str)
}
return writeFully(s.Writer, []byte(str))
wrappedErr := util.WrapErrorf(writeFully(s.Writer, []byte(str)),
util.NetworkError, "error sending octet string on sync sender")
return wrappedErr
}
func (s *realSyncSender) SendInt32(val int32) error {
return binary.Write(s.Writer, binary.LittleEndian, val)
return util.WrapErrorf(binary.Write(s.Writer, binary.LittleEndian, val),
util.NetworkError, "error sending int on sync sender")
}
func (s *realSyncSender) SendFileMode(mode os.FileMode) error {
return binary.Write(s.Writer, binary.LittleEndian, mode)
return util.WrapErrorf(binary.Write(s.Writer, binary.LittleEndian, mode),
util.NetworkError, "error sending filemode on sync sender")
}
func (s *realSyncSender) SendTime(t time.Time) error {
return s.SendInt32(int32(t.Unix()))
return util.WrapErrorf(s.SendInt32(int32(t.Unix())),
util.NetworkError, "error sending time on sync sender")
}
func (s *realSyncSender) SendString(str string) error {
@ -51,11 +59,12 @@ func (s *realSyncSender) SendString(str string) error {
if length > MaxChunkSize {
// This limit might not apply to filenames, but it's big enough
// that I don't think it will be a problem.
return fmt.Errorf("str must be <= %d in length", MaxChunkSize)
return util.AssertionErrorf("str must be <= %d in length", MaxChunkSize)
}
if err := s.SendInt32(int32(length)); err != nil {
return err
return util.WrapErrorf(err, util.NetworkError, "error sending string length on sync sender")
}
return writeFully(s.Writer, []byte(str))
return util.WrapErrorf(writeFully(s.Writer, []byte(str)),
util.NetworkError, "error sending string on sync sender")
}

View file

@ -3,41 +3,85 @@ package wire
import (
"fmt"
"io"
"github.com/zach-klippenstein/goadb/util"
)
// ErrorResponseDetails is an error message returned by the server for a particular request.
type ErrorResponseDetails struct {
Request string
ServerMsg string
}
// Reads the status, and if failure, reads the message and returns it as an error.
// If the status is success, doesn't read the message.
// req is just used to populate the AdbError, and can be nil.
func ReadStatusFailureAsError(s Scanner, req string) error {
status, err := s.ReadStatus()
if err != nil {
return fmt.Errorf("error reading status for %s: %+v", req, err)
return util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req)
}
if !status.IsSuccess() {
msg, err := s.ReadMessage()
if err != nil {
return fmt.Errorf("server returned error for %s, but couldn't read the error message: %+v", err)
return util.WrapErrorf(err, util.NetworkError,
"server returned error for %s, but couldn't read the error message", req)
}
return &AdbServerError{
Request: req,
ServerMsg: string(msg),
}
return adbServerError(req, string(msg))
}
return nil
}
func adbServerError(request string, serverMsg string) error {
var msg string
if request == "" {
msg = fmt.Sprintf("server error: %s", serverMsg)
} else {
msg = fmt.Sprintf("server error for %s request: %s", request, serverMsg)
}
errCode := util.AdbError
if serverMsg == "device not found" {
errCode = util.DeviceNotFound
}
return &util.Err{
Code: errCode,
Message: msg,
Details: ErrorResponseDetails{
Request: request,
ServerMsg: serverMsg,
},
}
}
func errIncompleteMessage(description string, actual int, expected int) error {
return &util.Err{
Code: util.NetworkError,
Message: fmt.Sprintf("incomplete %s: read %d bytes, expecting %d", description, actual, expected),
Details: struct {
ActualReadBytes int
ExpectedBytes int
}{
ActualReadBytes: actual,
ExpectedBytes: expected,
},
}
}
// writeFully writes all of data to w.
// Inverse of io.ReadFully().
func writeFully(w io.Writer, data []byte) error {
for len(data) > 0 {
n, err := w.Write(data)
offset := 0
for offset < len(data) {
n, err := w.Write(data[offset:])
if err != nil {
return err
return util.WrapErrorf(err, util.NetworkError, "error writing %d bytes at offset %d", len(data), offset)
}
data = data[n:]
offset += n
}
return nil
}