diff --git a/device_client.go b/device_client.go index f02f196..46c1f76 100644 --- a/device_client.go +++ b/device_client.go @@ -5,6 +5,7 @@ import ( "io" "strings" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -28,19 +29,23 @@ func (c *DeviceClient) String() string { // get-product is documented, but not implemented in the server. // TODO(z): Make getProduct exported if get-product is ever implemented in adb. func (c *DeviceClient) getProduct() (string, error) { - return c.getAttribute("get-product") + attr, err := c.getAttribute("get-product") + return attr, wrapClientError(err, c, "GetProduct") } func (c *DeviceClient) GetSerial() (string, error) { - return c.getAttribute("get-serialno") + attr, err := c.getAttribute("get-serialno") + return attr, wrapClientError(err, c, "GetSerial") } func (c *DeviceClient) GetDevicePath() (string, error) { - return c.getAttribute("get-devpath") + attr, err := c.getAttribute("get-devpath") + return attr, wrapClientError(err, c, "GetDevicePath") } func (c *DeviceClient) GetState() (string, error) { - return c.getAttribute("get-state") + attr, err := c.getAttribute("get-state") + return attr, wrapClientError(err, c, "GetState") } /* @@ -62,12 +67,12 @@ contain double quotes. func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) { cmd, err := prepareCommandLine(cmd, args...) if err != nil { - return "", err + return "", wrapClientError(err, c, "RunCommand") } conn, err := c.dialDevice() if err != nil { - return "", err + return "", wrapClientError(err, c, "RunCommand") } defer conn.Close() @@ -77,18 +82,14 @@ func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) { // We read until the stream is closed. // So, we can't use conn.RoundTripSingleResponse. if err = conn.SendMessage([]byte(req)); err != nil { - return "", err + return "", wrapClientError(err, c, "RunCommand") } if err = wire.ReadStatusFailureAsError(conn, req); err != nil { - return "", err + return "", wrapClientError(err, c, "RunCommand") } resp, err := conn.ReadUntilEof() - if err != nil { - return "", err - } - - return string(resp), nil + return string(resp), wrapClientError(err, c, "RunCommand") } /* @@ -103,39 +104,42 @@ Source: https://android.googlesource.com/platform/system/core/+/master/adb/SERVI func (c *DeviceClient) Remount() (string, error) { conn, err := c.dialDevice() if err != nil { - return "", err + return "", wrapClientError(err, c, "Remount") } defer conn.Close() resp, err := conn.RoundTripSingleResponse([]byte("remount")) - return string(resp), err + return string(resp), wrapClientError(err, c, "Remount") } func (c *DeviceClient) ListDirEntries(path string) (*DirEntries, error) { conn, err := c.getSyncConn() if err != nil { - return nil, err + return nil, wrapClientError(err, c, "ListDirEntries(%s)", path) } - return listDirEntries(conn, path) + entries, err := listDirEntries(conn, path) + return entries, wrapClientError(err, c, "ListDirEntries(%s)", path) } func (c *DeviceClient) Stat(path string) (*DirEntry, error) { conn, err := c.getSyncConn() if err != nil { - return nil, err + return nil, wrapClientError(err, c, "Stat(%s)", path) } - return stat(conn, path) + entry, err := stat(conn, path) + return entry, wrapClientError(err, c, "Stat(%s)", path) } func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) { conn, err := c.getSyncConn() if err != nil { - return nil, err + return nil, wrapClientError(err, c, "OpenRead(%s)", path) } - return receiveFile(conn, path) + reader, err := receiveFile(conn, path) + return reader, wrapClientError(err, c, "OpenRead(%s)", path) } // getAttribute returns the first message returned by the server by running @@ -149,18 +153,35 @@ func (c *DeviceClient) getAttribute(attr string) (string, error) { return string(resp), nil } +func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) { + conn, err := c.dialDevice() + if err != nil { + return nil, err + } + + // Switch the connection to sync mode. + if err := wire.SendMessageString(conn, "sync:"); err != nil { + return nil, err + } + if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil { + return nil, err + } + + return conn.NewSyncConn(), nil +} + // dialDevice switches the connection to communicate directly with the device // by requesting the transport defined by the DeviceDescriptor. func (c *DeviceClient) dialDevice() (*wire.Conn, error) { conn, err := c.config.Dialer.Dial() if err != nil { - return nil, fmt.Errorf("error dialing adb server (%s): %+v", c.config.Dialer, err) + return nil, err } req := fmt.Sprintf("host:%s", c.descriptor.getTransportDescriptor()) if err = wire.SendMessageString(conn, req); err != nil { conn.Close() - return nil, fmt.Errorf("error connecting to device '%s': %+v", c.descriptor, err) + return nil, util.WrapErrf(err, "error connecting to device '%s'", c.descriptor) } if err = wire.ReadStatusFailureAsError(conn, req); err != nil { @@ -171,33 +192,16 @@ func (c *DeviceClient) dialDevice() (*wire.Conn, error) { return conn, nil } -func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) { - conn, err := c.dialDevice() - if err != nil { - return nil, fmt.Errorf("error connecting to device for sync: %+v", err) - } - - // Switch the connection to sync mode. - if err := wire.SendMessageString(conn, "sync:"); err != nil { - return nil, fmt.Errorf("error requesting sync mode: %+v", err) - } - if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil { - return nil, err - } - - return conn.NewSyncConn(), nil -} - // prepareCommandLine validates the command and argument strings, quotes // arguments if required, and joins them into a valid adb command string. func prepareCommandLine(cmd string, args ...string) (string, error) { if isBlank(cmd) { - return "", fmt.Errorf("command cannot be empty") + return "", util.AssertionErrorf("command cannot be empty") } for i, arg := range args { if strings.ContainsRune(arg, '"') { - return "", fmt.Errorf("arg at index %d contains an invalid double quote: %s", i, arg) + return "", util.Errorf(util.ParseError, "arg at index %d contains an invalid double quote: %s", i, arg) } if containsWhitespace(arg) { args[i] = fmt.Sprintf("\"%s\"", arg) diff --git a/device_client_test.go b/device_client_test.go index 18b8a52..fa55e16 100644 --- a/device_client_test.go +++ b/device_client_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -52,12 +53,14 @@ func TestPrepareCommandLineNoArgs(t *testing.T) { func TestPrepareCommandLineEmptyCommand(t *testing.T) { _, err := prepareCommandLine("") - assert.EqualError(t, err, "command cannot be empty") + assert.Equal(t, util.AssertionError, code(err)) + assert.Equal(t, "command cannot be empty", message(err)) } func TestPrepareCommandLineBlankCommand(t *testing.T) { _, err := prepareCommandLine(" ") - assert.EqualError(t, err, "command cannot be empty") + assert.Equal(t, util.AssertionError, code(err)) + assert.Equal(t, "command cannot be empty", message(err)) } func TestPrepareCommandLineCleanArgs(t *testing.T) { @@ -74,5 +77,14 @@ func TestPrepareCommandLineArgWithWhitespaceQuotes(t *testing.T) { func TestPrepareCommandLineArgWithDoubleQuoteFails(t *testing.T) { _, err := prepareCommandLine("cmd", "quoted\"arg") - assert.EqualError(t, err, "arg at index 0 contains an invalid double quote: quoted\"arg") + assert.Equal(t, util.ParseError, code(err)) + assert.Equal(t, "arg at index 0 contains an invalid double quote: quoted\"arg", message(err)) +} + +func code(err error) util.ErrCode { + return err.(*util.Err).Code +} + +func message(err error) string { + return err.(*util.Err).Message } diff --git a/device_info.go b/device_info.go index c1af642..d80a656 100644 --- a/device_info.go +++ b/device_info.go @@ -2,8 +2,9 @@ package goadb import ( "bufio" - "fmt" "strings" + + "github.com/zach-klippenstein/goadb/util" ) type DeviceInfo struct { @@ -26,7 +27,7 @@ func (d *DeviceInfo) IsUsb() bool { func newDevice(serial string, attrs map[string]string) (*DeviceInfo, error) { if serial == "" { - return nil, fmt.Errorf("device serial cannot be blank") + return nil, util.AssertionErrorf("device serial cannot be blank") } return &DeviceInfo{ @@ -49,9 +50,6 @@ func parseDeviceList(list string, lineParseFunc func(string) (*DeviceInfo, error } devices = append(devices, device) } - if err := scanner.Err(); err != nil { - return nil, err - } return devices, nil } @@ -59,7 +57,8 @@ func parseDeviceList(list string, lineParseFunc func(string) (*DeviceInfo, error func parseDeviceShort(line string) (*DeviceInfo, error) { fields := strings.Fields(line) if len(fields) != 2 { - return nil, fmt.Errorf("malformed device line, expected 2 fields but found %d", len(fields)) + return nil, util.Errorf(util.ParseError, + "malformed device line, expected 2 fields but found %d", len(fields)) } return newDevice(fields[0], map[string]string{}) @@ -68,7 +67,8 @@ func parseDeviceShort(line string) (*DeviceInfo, error) { func parseDeviceLong(line string) (*DeviceInfo, error) { fields := strings.Fields(line) if len(fields) < 5 { - return nil, fmt.Errorf("malformed device line, expected at least 5 fields but found %d", len(fields)) + return nil, util.Errorf(util.ParseError, + "malformed device line, expected at least 5 fields but found %d", len(fields)) } attrs := parseDeviceAttributes(fields[2:]) diff --git a/dialer.go b/dialer.go index c60a2c0..1d1a45a 100644 --- a/dialer.go +++ b/dialer.go @@ -5,6 +5,7 @@ import ( "net" "runtime" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -50,9 +51,10 @@ func (d *netDialer) Dial() (*wire.Conn, error) { host := d.Host port := d.Port - netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) + address := fmt.Sprintf("%s:%d", host, port) + netConn, err := net.Dial("tcp", address) if err != nil { - return nil, err + return nil, util.WrapErrorf(err, util.NetworkError, "error dialing %s", address) } conn := &wire.Conn{ @@ -68,7 +70,6 @@ func (d *netDialer) Dial() (*wire.Conn, error) { return conn, nil } -// TODO(zach): Make this unexported. func roundTripSingleResponse(d Dialer, req string) ([]byte, error) { conn, err := d.Dial() if err != nil { diff --git a/host_client.go b/host_client.go index a9a9cad..e2be71a 100644 --- a/host_client.go +++ b/host_client.go @@ -4,6 +4,7 @@ package goadb import ( "strconv" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -46,11 +47,14 @@ func NewHostClient(config ClientConfig) *HostClient { func (c *HostClient) GetServerVersion() (int, error) { resp, err := roundTripSingleResponse(c.config.Dialer, "host:version") if err != nil { - return 0, err + return 0, wrapClientError(err, c, "GetServerVersion") } - version, err := strconv.ParseInt(string(resp), 16, 32) - return int(version), err + version, err := c.parseServerVersion(resp) + if err != nil { + return 0, wrapClientError(err, c, "GetServerVersion") + } + return version, nil } /* @@ -62,12 +66,12 @@ Corresponds to the command: func (c *HostClient) KillServer() error { conn, err := c.config.Dialer.Dial() if err != nil { - return err + return wrapClientError(err, c, "KillServer") } defer conn.Close() if err = wire.SendMessageString(conn, "host:kill"); err != nil { - return err + return wrapClientError(err, c, "KillServer") } return nil @@ -82,12 +86,12 @@ Corresponds to the command: func (c *HostClient) ListDeviceSerials() ([]string, error) { resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices") if err != nil { - return nil, err + return nil, wrapClientError(err, c, "ListDeviceSerials") } devices, err := parseDeviceList(string(resp), parseDeviceShort) if err != nil { - return nil, err + return nil, wrapClientError(err, c, "ListDeviceSerials") } serials := make([]string, len(devices)) @@ -106,8 +110,22 @@ Corresponds to the command: func (c *HostClient) ListDevices() ([]*DeviceInfo, error) { resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices-l") if err != nil { - return nil, err + return nil, wrapClientError(err, c, "ListDevices") } - return parseDeviceList(string(resp), parseDeviceLong) + devices, err := parseDeviceList(string(resp), parseDeviceLong) + if err != nil { + return nil, wrapClientError(err, c, "ListDevices") + } + return devices, nil +} + +func (c *HostClient) parseServerVersion(versionRaw []byte) (int, error) { + versionStr := string(versionRaw) + version, err := strconv.ParseInt(versionStr, 16, 32) + if err != nil { + return 0, util.WrapErrorf(err, util.ParseError, + "error parsing server version: %s", versionStr) + } + return int(version), nil } diff --git a/host_client_test.go b/host_client_test.go index 7f26b1d..fb2ec71 100644 --- a/host_client_test.go +++ b/host_client_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -46,7 +47,7 @@ func (s *MockServer) ReadStatus() (wire.StatusCode, error) { func (s *MockServer) ReadMessage() ([]byte, error) { if s.nextMsgIndex >= len(s.Messages) { - return nil, io.EOF + return nil, util.WrapErrorf(io.EOF, util.NetworkError, "") } s.nextMsgIndex++ diff --git a/server_controller.go b/server_controller.go index c9cd01e..9025a25 100644 --- a/server_controller.go +++ b/server_controller.go @@ -1,6 +1,10 @@ package goadb -import "os/exec" +import ( + "os/exec" + + "github.com/zach-klippenstein/goadb/util" +) /* StartServer ensures there is a server running. @@ -10,5 +14,6 @@ Currently implemented by just running */ func StartServer() error { cmd := exec.Command("adb", "start-server") - return cmd.Run() + err := cmd.Run() + return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s", err) } diff --git a/sync_client.go b/sync_client.go index ac40470..4ec31a0 100644 --- a/sync_client.go +++ b/sync_client.go @@ -2,11 +2,11 @@ package goadb import ( - "fmt" "io" "os" "time" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -25,7 +25,7 @@ func stat(conn *wire.SyncConn, path string) (*DirEntry, error) { return nil, err } if id != "STAT" { - return nil, fmt.Errorf("expected stat ID 'STAT', but got '%s'", id) + return nil, util.Errorf(util.AssertionError, "expected stat ID 'STAT', but got '%s'", id) } return readStat(conn) @@ -56,24 +56,24 @@ func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) { func readStat(s wire.SyncScanner) (entry *DirEntry, err error) { mode, err := s.ReadFileMode() if err != nil { - err = fmt.Errorf("error reading file mode: %v", err) + err = util.WrapErrf(err, "error reading file mode: %v", err) return } size, err := s.ReadInt32() if err != nil { - err = fmt.Errorf("error reading file size: %v", err) + err = util.WrapErrf(err, "error reading file size: %v", err) return } mtime, err := s.ReadTime() if err != nil { - err = fmt.Errorf("error reading file time: %v", err) + err = util.WrapErrf(err, "error reading file time: %v", err) return } // adb doesn't indicate when a file doesn't exist, but will return all zeros. // Theoretically this could be an actual file, but that's very unlikely. if mode == os.FileMode(0) && size == 0 && mtime == zeroTime { - return nil, os.ErrNotExist + return nil, util.Errorf(util.FileNoExistError, "file doesn't exist") } entry = &DirEntry{ diff --git a/sync_client_test.go b/sync_client_test.go index ffff2f2..050bd8f 100644 --- a/sync_client_test.go +++ b/sync_client_test.go @@ -8,6 +8,8 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zach-klippenstein/goadb/util" "github.com/zach-klippenstein/goadb/wire" ) @@ -26,6 +28,7 @@ func TestStatValid(t *testing.T) { entry, err := stat(conn, "/thing") assert.NoError(t, err) + require.NotNil(t, entry) assert.Equal(t, mode, entry.Mode, "expected os.FileMode %s, got %s", mode, entry.Mode) assert.Equal(t, int32(4), entry.Size) assert.Equal(t, someTime, entry.ModifiedAt) @@ -54,5 +57,5 @@ func TestStatNoExist(t *testing.T) { entry, err := stat(conn, "/") assert.Nil(t, entry) - assert.Equal(t, os.ErrNotExist, err) + assert.Equal(t, util.FileNoExistError, err.(*util.Err).Code) } diff --git a/util.go b/util.go index e6f7c19..6291f23 100644 --- a/util.go +++ b/util.go @@ -1,8 +1,12 @@ package goadb -import "strings" import ( + "fmt" + "reflect" "regexp" + "strings" + + "github.com/zach-klippenstein/goadb/util" ) var ( @@ -16,3 +20,21 @@ func containsWhitespace(str string) bool { func isBlank(str string) bool { return whitespaceRegex.MatchString(str) } + +func wrapClientError(err error, client interface{}, operation string, args ...interface{}) error { + if err == nil { + return nil + } + if _, ok := err.(*util.Err); !ok { + panic("err is not a *util.Err") + } + + clientType := reflect.TypeOf(client) + + return &util.Err{ + Code: err.(*util.Err).Code, + Cause: err, + Message: fmt.Sprintf("error performing %s on %s", fmt.Sprintf(operation, args...), clientType), + Details: client, + } +} diff --git a/util/errcode_string.go b/util/errcode_string.go new file mode 100644 index 0000000..5bb24c6 --- /dev/null +++ b/util/errcode_string.go @@ -0,0 +1,16 @@ +// generated by stringer -type=ErrCode; DO NOT EDIT + +package util + +import "fmt" + +const _ErrCode_name = "AssertionErrorParseErrorServerNotAvailableNetworkErrorAdbErrorDeviceNotFoundFileNoExistError" + +var _ErrCode_index = [...]uint8{0, 14, 24, 42, 54, 62, 76, 92} + +func (i ErrCode) String() string { + if i+1 >= ErrCode(len(_ErrCode_index)) { + return fmt.Sprintf("ErrCode(%d)", i) + } + return _ErrCode_name[_ErrCode_index[i]:_ErrCode_index[i+1]] +} diff --git a/util/error.go b/util/error.go new file mode 100644 index 0000000..93f212e --- /dev/null +++ b/util/error.go @@ -0,0 +1,96 @@ +package util + +import "fmt" + +// Err is the implementation of error that all goadb functions return. +type Err struct { + // Code is the high-level "type" of error. + Code ErrCode + // Message is a human-readable description of the error. + Message string + // Details is optional, and can be used to associate any auxiliary data with an error. + Details interface{} + // Cause is optional, and points to the more specific error that caused this one. + Cause error +} + +var _ error = &Err{} + +//go:generate stringer -type=ErrCode +type ErrCode byte + +const ( + AssertionError ErrCode = iota + ParseError ErrCode = iota + // The server was not available on the request port and could not be started. + ServerNotAvailable ErrCode = iota + // General network error communicating with the server. + NetworkError ErrCode = iota + // The server returned an error message, but we couldn't parse it. + AdbError ErrCode = iota + // The server returned a "device not found" error. + DeviceNotFound ErrCode = iota + // Tried to perform an operation on a path that doesn't exist on the device. + FileNoExistError ErrCode = iota +) + +func Errorf(code ErrCode, format string, args ...interface{}) error { + return &Err{ + Code: code, + Message: fmt.Sprintf(format, args...), + } +} + +/* +WrapErrf returns an *Err that wraps another *Err and has the same ErrCode. +Panics if cause is not an *Err. + +To wrap generic errors, use WrapErrorf. +*/ +func WrapErrf(cause error, format string, args ...interface{}) error { + if cause == nil { + return nil + } + + err := cause.(*Err) + return &Err{ + Code: err.Code, + Message: fmt.Sprintf(format, args...), + Cause: err, + } +} + +/* +WrapErrorf returns an *Err that wraps another arbitrary error with an ErrCode and a message. + +If cause is nil, returns nil, so you can use it like + return util.WrapErrorf(DoSomethingDangerous(), util.NetworkError, "well that didn't work") + +If cause is known to be of type *Err, use WrapErrf. +*/ +func WrapErrorf(cause error, code ErrCode, format string, args ...interface{}) error { + if cause == nil { + return nil + } + + return &Err{ + Code: code, + Message: fmt.Sprintf(format, args...), + Cause: cause, + } +} + +func AssertionErrorf(format string, args ...interface{}) error { + return &Err{ + Code: AssertionError, + Message: fmt.Sprintf(format, args...), + } +} + +func (err *Err) Error() string { + msg := fmt.Sprintf("%s: %s", err.Code, err.Message) + if err.Details != nil { + msg = fmt.Sprintf("%s (%+v)", msg, err.Details) + } + return msg +} diff --git a/wire/adb_server_error.go b/wire/adb_server_error.go deleted file mode 100644 index df79b28..0000000 --- a/wire/adb_server_error.go +++ /dev/null @@ -1,24 +0,0 @@ -package wire - -import ( - "fmt" -) - -type AdbServerError struct { - Request string - ServerMsg string -} - -var _ error = &AdbServerError{} - -func (e *AdbServerError) Error() string { - if e.Request == "" { - return fmt.Sprintf("server error: %s", e.ServerMsg) - } else { - return fmt.Sprintf("server error for %s request: %s", e.Request, e.ServerMsg) - } -} - -func incompleteMessage(description string, actual int, expected int) error { - return fmt.Errorf("incomplete %s: read %d bytes, expecting %d", description, actual, expected) -} diff --git a/wire/conn.go b/wire/conn.go index 88fbb86..c976ebc 100644 --- a/wire/conn.go +++ b/wire/conn.go @@ -1,5 +1,7 @@ package wire +import "github.com/zach-klippenstein/goadb/util" + const ( // The official implementation of adb imposes an undocumented 255-byte limit // on messages. @@ -57,8 +59,20 @@ func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) { } func (conn *Conn) Close() error { - if err := conn.Sender.Close(); err != nil { - return err + errs := struct { + SenderErr error + ScannerErr error + }{ + SenderErr: conn.Sender.Close(), + ScannerErr: conn.Scanner.Close(), } - return conn.Scanner.Close() + + if errs.ScannerErr != nil || errs.SenderErr != nil { + return &util.Err{ + Code: util.NetworkError, + Message: "error closing connection", + Details: errs, + } + } + return nil } diff --git a/wire/scanner.go b/wire/scanner.go index 880b648..23f4bb0 100644 --- a/wire/scanner.go +++ b/wire/scanner.go @@ -1,10 +1,11 @@ package wire import ( - "fmt" "io" "io/ioutil" "strconv" + + "github.com/zach-klippenstein/goadb/util" ) // StatusCodes are returned by the server. If the code indicates failure, the @@ -54,33 +55,41 @@ func ReadMessageString(s Scanner) (string, error) { func (s *realScanner) ReadStatus() (StatusCode, error) { status := make([]byte, 4) n, err := io.ReadFull(s.reader, status) + if err != nil && err != io.ErrUnexpectedEOF { - return "", err + return "", util.WrapErrorf(err, util.NetworkError, "error reading status") } else if err == io.ErrUnexpectedEOF { - return StatusCode(status), incompleteMessage("status", n, 4) + return StatusCode(status), errIncompleteMessage("status", n, 4) } return StatusCode(status), nil } func (s *realScanner) ReadMessage() ([]byte, error) { + var err error + length, err := s.readLength() if err != nil { - return nil, err + return nil, util.WrapErrorf(err, util.NetworkError, "error reading message length") } data := make([]byte, length) n, err := io.ReadFull(s.reader, data) + if err != nil && err != io.ErrUnexpectedEOF { - return data, fmt.Errorf("error reading message data: %v", err) + return data, util.WrapErrorf(err, util.NetworkError, "error reading message data") } else if err == io.ErrUnexpectedEOF { - return data, incompleteMessage("message data", n, length) + return data, errIncompleteMessage("message data", n, length) } return data, nil } func (s *realScanner) ReadUntilEof() ([]byte, error) { - return ioutil.ReadAll(s.reader) + data, err := ioutil.ReadAll(s.reader) + if err != nil { + return nil, util.WrapErrorf(err, util.NetworkError, "error reading until EOF") + } + return data, nil } func (s *realScanner) NewSyncScanner() SyncScanner { @@ -88,21 +97,21 @@ func (s *realScanner) NewSyncScanner() SyncScanner { } func (s *realScanner) Close() error { - return s.reader.Close() + return util.WrapErrorf(s.reader.Close(), util.NetworkError, "error closing scanner") } func (s *realScanner) readLength() (int, error) { lengthHex := make([]byte, 4) n, err := io.ReadFull(s.reader, lengthHex) if err != nil && err != io.ErrUnexpectedEOF { - return 0, err + return 0, util.WrapErrorf(err, util.NetworkError, "error reading length") } else if err == io.ErrUnexpectedEOF { - return 0, incompleteMessage("length", n, 4) + return 0, errIncompleteMessage("length", n, 4) } length, err := strconv.ParseInt(string(lengthHex), 16, 64) if err != nil { - return 0, fmt.Errorf("invalid hex length: %v", err) + return 0, util.WrapErrorf(err, util.NetworkError, "could not parse hex length %v", lengthHex) } // Clip the length to 255, as per the Google implementation. diff --git a/wire/sender.go b/wire/sender.go index d80363f..c333e11 100644 --- a/wire/sender.go +++ b/wire/sender.go @@ -3,6 +3,8 @@ package wire import ( "fmt" "io" + + "github.com/zach-klippenstein/goadb/util" ) // Sender sends messages to the server. @@ -28,7 +30,7 @@ func SendMessageString(s Sender, msg string) error { func (s *realSender) SendMessage(msg []byte) error { if len(msg) > MaxMessageLength { - return fmt.Errorf("message length exceeds maximum: %d", len(msg)) + return util.AssertionErrorf("message length exceeds maximum: %d", len(msg)) } lengthAndMsg := fmt.Sprintf("%04x%s", len(msg), msg) @@ -40,7 +42,7 @@ func (s *realSender) NewSyncSender() SyncSender { } func (s *realSender) Close() error { - return s.writer.Close() + return util.WrapErrorf(s.writer.Close(), util.NetworkError, "error closing sender") } var _ Sender = &realSender{} diff --git a/wire/sync_scanner.go b/wire/sync_scanner.go index 34c8a8b..803d65a 100644 --- a/wire/sync_scanner.go +++ b/wire/sync_scanner.go @@ -2,10 +2,11 @@ package wire import ( "encoding/binary" - "fmt" "io" "os" "time" + + "github.com/zach-klippenstein/goadb/util" ) type SyncScanner interface { @@ -38,10 +39,10 @@ func NewSyncScanner(r io.Reader) SyncScanner { func RequireOctetString(s SyncScanner, expected string) error { actual, err := s.ReadOctetString() if err != nil { - return fmt.Errorf("expected to read '%s', got err: %v", expected, err) + return util.WrapErrorf(err, util.NetworkError, "expected to read '%s'", expected) } if actual != expected { - return fmt.Errorf("expected to read '%s', got '%s'", expected, actual) + return util.AssertionErrorf("expected to read '%s', got '%s'", expected, actual) } return nil } @@ -49,10 +50,11 @@ func RequireOctetString(s SyncScanner, expected string) error { func (s *realSyncScanner) ReadOctetString() (string, error) { octet := make([]byte, 4) n, err := io.ReadFull(s.Reader, octet) + if err != nil && err != io.ErrUnexpectedEOF { - return "", err + return "", util.WrapErrorf(err, util.NetworkError, "error reading octet string from sync scanner") } else if err == io.ErrUnexpectedEOF { - return "", incompleteMessage("octet", n, 4) + return "", errIncompleteMessage("octet", n, 4) } return string(octet), nil @@ -60,20 +62,21 @@ func (s *realSyncScanner) ReadOctetString() (string, error) { func (s *realSyncScanner) ReadInt32() (int32, error) { var value int32 err := binary.Read(s.Reader, binary.LittleEndian, &value) - return value, err + return value, util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner") } -func (s *realSyncScanner) ReadFileMode() (filemode os.FileMode, err error) { +func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) { var value uint32 - err = binary.Read(s.Reader, binary.LittleEndian, &value) - if err == nil { - filemode = ParseFileModeFromAdb(value) + err := binary.Read(s.Reader, binary.LittleEndian, &value) + if err != nil { + return 0, util.WrapErrorf(err, util.NetworkError, "error reading filemode from sync scanner") } - return + return ParseFileModeFromAdb(value), nil + } func (s *realSyncScanner) ReadTime() (time.Time, error) { seconds, err := s.ReadInt32() if err != nil { - return time.Time{}, err + return time.Time{}, util.WrapErrorf(err, util.NetworkError, "error reading time from sync scanner") } return time.Unix(int64(seconds), 0).UTC(), nil @@ -82,15 +85,15 @@ func (s *realSyncScanner) ReadTime() (time.Time, error) { func (s *realSyncScanner) ReadString() (string, error) { length, err := s.ReadInt32() if err != nil { - return "", err + return "", util.WrapErrorf(err, util.NetworkError, "error reading length from sync scanner") } bytes := make([]byte, length) - n, err := io.ReadFull(s.Reader, bytes) - if err != nil && err != io.ErrUnexpectedEOF { - return "", err - } else if err == io.ErrUnexpectedEOF { - return "", incompleteMessage("bytes", n, int(length)) + n, rawErr := io.ReadFull(s.Reader, bytes) + if rawErr != nil && rawErr != io.ErrUnexpectedEOF { + return "", util.WrapErrorf(rawErr, util.NetworkError, "error reading string from sync scanner") + } else if rawErr == io.ErrUnexpectedEOF { + return "", errIncompleteMessage("bytes", n, int(length)) } return string(bytes), nil @@ -98,7 +101,7 @@ func (s *realSyncScanner) ReadString() (string, error) { func (s *realSyncScanner) ReadBytes() (io.Reader, error) { length, err := s.ReadInt32() if err != nil { - return nil, err + return nil, util.WrapErrorf(err, util.NetworkError, "error reading bytes from sync scanner") } return io.LimitReader(s.Reader, int64(length)), nil @@ -106,7 +109,7 @@ func (s *realSyncScanner) ReadBytes() (io.Reader, error) { func (s *realSyncScanner) Close() error { if closer, ok := s.Reader.(io.Closer); ok { - return closer.Close() + return util.WrapErrorf(closer.Close(), util.NetworkError, "error closing sync scanner") } return nil } diff --git a/wire/sync_sender.go b/wire/sync_sender.go index 345f91e..15dfde8 100644 --- a/wire/sync_sender.go +++ b/wire/sync_sender.go @@ -2,10 +2,11 @@ package wire import ( "encoding/binary" - "fmt" "io" "os" "time" + + "github.com/zach-klippenstein/goadb/util" ) type SyncSender interface { @@ -29,21 +30,28 @@ func NewSyncSender(w io.Writer) SyncSender { func (s *realSyncSender) SendOctetString(str string) error { if len(str) != 4 { - return fmt.Errorf("octet string must be exactly 4 bytes: '%s'", str) + return util.AssertionErrorf("octet string must be exactly 4 bytes: '%s'", str) } - return writeFully(s.Writer, []byte(str)) + + wrappedErr := util.WrapErrorf(writeFully(s.Writer, []byte(str)), + util.NetworkError, "error sending octet string on sync sender") + + return wrappedErr } func (s *realSyncSender) SendInt32(val int32) error { - return binary.Write(s.Writer, binary.LittleEndian, val) + return util.WrapErrorf(binary.Write(s.Writer, binary.LittleEndian, val), + util.NetworkError, "error sending int on sync sender") } func (s *realSyncSender) SendFileMode(mode os.FileMode) error { - return binary.Write(s.Writer, binary.LittleEndian, mode) + return util.WrapErrorf(binary.Write(s.Writer, binary.LittleEndian, mode), + util.NetworkError, "error sending filemode on sync sender") } func (s *realSyncSender) SendTime(t time.Time) error { - return s.SendInt32(int32(t.Unix())) + return util.WrapErrorf(s.SendInt32(int32(t.Unix())), + util.NetworkError, "error sending time on sync sender") } func (s *realSyncSender) SendString(str string) error { @@ -51,11 +59,12 @@ func (s *realSyncSender) SendString(str string) error { if length > MaxChunkSize { // This limit might not apply to filenames, but it's big enough // that I don't think it will be a problem. - return fmt.Errorf("str must be <= %d in length", MaxChunkSize) + return util.AssertionErrorf("str must be <= %d in length", MaxChunkSize) } if err := s.SendInt32(int32(length)); err != nil { - return err + return util.WrapErrorf(err, util.NetworkError, "error sending string length on sync sender") } - return writeFully(s.Writer, []byte(str)) + return util.WrapErrorf(writeFully(s.Writer, []byte(str)), + util.NetworkError, "error sending string on sync sender") } diff --git a/wire/util.go b/wire/util.go index cddb3f5..2da81cb 100644 --- a/wire/util.go +++ b/wire/util.go @@ -3,41 +3,85 @@ package wire import ( "fmt" "io" + + "github.com/zach-klippenstein/goadb/util" ) +// ErrorResponseDetails is an error message returned by the server for a particular request. +type ErrorResponseDetails struct { + Request string + ServerMsg string +} + // Reads the status, and if failure, reads the message and returns it as an error. // If the status is success, doesn't read the message. // req is just used to populate the AdbError, and can be nil. func ReadStatusFailureAsError(s Scanner, req string) error { status, err := s.ReadStatus() if err != nil { - return fmt.Errorf("error reading status for %s: %+v", req, err) + return util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req) } if !status.IsSuccess() { msg, err := s.ReadMessage() if err != nil { - return fmt.Errorf("server returned error for %s, but couldn't read the error message: %+v", err) + return util.WrapErrorf(err, util.NetworkError, + "server returned error for %s, but couldn't read the error message", req) } - return &AdbServerError{ - Request: req, - ServerMsg: string(msg), - } + return adbServerError(req, string(msg)) } return nil } +func adbServerError(request string, serverMsg string) error { + var msg string + if request == "" { + msg = fmt.Sprintf("server error: %s", serverMsg) + } else { + msg = fmt.Sprintf("server error for %s request: %s", request, serverMsg) + } + + errCode := util.AdbError + if serverMsg == "device not found" { + errCode = util.DeviceNotFound + } + + return &util.Err{ + Code: errCode, + Message: msg, + Details: ErrorResponseDetails{ + Request: request, + ServerMsg: serverMsg, + }, + } +} + +func errIncompleteMessage(description string, actual int, expected int) error { + return &util.Err{ + Code: util.NetworkError, + Message: fmt.Sprintf("incomplete %s: read %d bytes, expecting %d", description, actual, expected), + Details: struct { + ActualReadBytes int + ExpectedBytes int + }{ + ActualReadBytes: actual, + ExpectedBytes: expected, + }, + } +} + // writeFully writes all of data to w. // Inverse of io.ReadFully(). func writeFully(w io.Writer, data []byte) error { - for len(data) > 0 { - n, err := w.Write(data) + offset := 0 + for offset < len(data) { + n, err := w.Write(data[offset:]) if err != nil { - return err + return util.WrapErrorf(err, util.NetworkError, "error writing %d bytes at offset %d", len(data), offset) } - data = data[n:] + offset += n } return nil }