Richer error handling.
This commit is contained in:
parent
0817d29438
commit
34b3a07ca8
|
@ -5,6 +5,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
"github.com/zach-klippenstein/goadb/wire"
|
"github.com/zach-klippenstein/goadb/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -28,19 +29,23 @@ func (c *DeviceClient) String() string {
|
||||||
// get-product is documented, but not implemented in the server.
|
// get-product is documented, but not implemented in the server.
|
||||||
// TODO(z): Make getProduct exported if get-product is ever implemented in adb.
|
// TODO(z): Make getProduct exported if get-product is ever implemented in adb.
|
||||||
func (c *DeviceClient) getProduct() (string, error) {
|
func (c *DeviceClient) getProduct() (string, error) {
|
||||||
return c.getAttribute("get-product")
|
attr, err := c.getAttribute("get-product")
|
||||||
|
return attr, wrapClientError(err, c, "GetProduct")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DeviceClient) GetSerial() (string, error) {
|
func (c *DeviceClient) GetSerial() (string, error) {
|
||||||
return c.getAttribute("get-serialno")
|
attr, err := c.getAttribute("get-serialno")
|
||||||
|
return attr, wrapClientError(err, c, "GetSerial")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DeviceClient) GetDevicePath() (string, error) {
|
func (c *DeviceClient) GetDevicePath() (string, error) {
|
||||||
return c.getAttribute("get-devpath")
|
attr, err := c.getAttribute("get-devpath")
|
||||||
|
return attr, wrapClientError(err, c, "GetDevicePath")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DeviceClient) GetState() (string, error) {
|
func (c *DeviceClient) GetState() (string, error) {
|
||||||
return c.getAttribute("get-state")
|
attr, err := c.getAttribute("get-state")
|
||||||
|
return attr, wrapClientError(err, c, "GetState")
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -62,12 +67,12 @@ contain double quotes.
|
||||||
func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) {
|
func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) {
|
||||||
cmd, err := prepareCommandLine(cmd, args...)
|
cmd, err := prepareCommandLine(cmd, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", wrapClientError(err, c, "RunCommand")
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := c.dialDevice()
|
conn, err := c.dialDevice()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", wrapClientError(err, c, "RunCommand")
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
@ -77,18 +82,14 @@ func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) {
|
||||||
// We read until the stream is closed.
|
// We read until the stream is closed.
|
||||||
// So, we can't use conn.RoundTripSingleResponse.
|
// So, we can't use conn.RoundTripSingleResponse.
|
||||||
if err = conn.SendMessage([]byte(req)); err != nil {
|
if err = conn.SendMessage([]byte(req)); err != nil {
|
||||||
return "", err
|
return "", wrapClientError(err, c, "RunCommand")
|
||||||
}
|
}
|
||||||
if err = wire.ReadStatusFailureAsError(conn, req); err != nil {
|
if err = wire.ReadStatusFailureAsError(conn, req); err != nil {
|
||||||
return "", err
|
return "", wrapClientError(err, c, "RunCommand")
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := conn.ReadUntilEof()
|
resp, err := conn.ReadUntilEof()
|
||||||
if err != nil {
|
return string(resp), wrapClientError(err, c, "RunCommand")
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(resp), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -103,39 +104,42 @@ Source: https://android.googlesource.com/platform/system/core/+/master/adb/SERVI
|
||||||
func (c *DeviceClient) Remount() (string, error) {
|
func (c *DeviceClient) Remount() (string, error) {
|
||||||
conn, err := c.dialDevice()
|
conn, err := c.dialDevice()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", wrapClientError(err, c, "Remount")
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
resp, err := conn.RoundTripSingleResponse([]byte("remount"))
|
resp, err := conn.RoundTripSingleResponse([]byte("remount"))
|
||||||
return string(resp), err
|
return string(resp), wrapClientError(err, c, "Remount")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DeviceClient) ListDirEntries(path string) (*DirEntries, error) {
|
func (c *DeviceClient) ListDirEntries(path string) (*DirEntries, error) {
|
||||||
conn, err := c.getSyncConn()
|
conn, err := c.getSyncConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, wrapClientError(err, c, "ListDirEntries(%s)", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return listDirEntries(conn, path)
|
entries, err := listDirEntries(conn, path)
|
||||||
|
return entries, wrapClientError(err, c, "ListDirEntries(%s)", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DeviceClient) Stat(path string) (*DirEntry, error) {
|
func (c *DeviceClient) Stat(path string) (*DirEntry, error) {
|
||||||
conn, err := c.getSyncConn()
|
conn, err := c.getSyncConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, wrapClientError(err, c, "Stat(%s)", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return stat(conn, path)
|
entry, err := stat(conn, path)
|
||||||
|
return entry, wrapClientError(err, c, "Stat(%s)", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) {
|
func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) {
|
||||||
conn, err := c.getSyncConn()
|
conn, err := c.getSyncConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, wrapClientError(err, c, "OpenRead(%s)", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
return receiveFile(conn, path)
|
reader, err := receiveFile(conn, path)
|
||||||
|
return reader, wrapClientError(err, c, "OpenRead(%s)", path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getAttribute returns the first message returned by the server by running
|
// getAttribute returns the first message returned by the server by running
|
||||||
|
@ -149,18 +153,35 @@ func (c *DeviceClient) getAttribute(attr string) (string, error) {
|
||||||
return string(resp), nil
|
return string(resp), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
|
||||||
|
conn, err := c.dialDevice()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Switch the connection to sync mode.
|
||||||
|
if err := wire.SendMessageString(conn, "sync:"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn.NewSyncConn(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// dialDevice switches the connection to communicate directly with the device
|
// dialDevice switches the connection to communicate directly with the device
|
||||||
// by requesting the transport defined by the DeviceDescriptor.
|
// by requesting the transport defined by the DeviceDescriptor.
|
||||||
func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
|
func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
|
||||||
conn, err := c.config.Dialer.Dial()
|
conn, err := c.config.Dialer.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error dialing adb server (%s): %+v", c.config.Dialer, err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
req := fmt.Sprintf("host:%s", c.descriptor.getTransportDescriptor())
|
req := fmt.Sprintf("host:%s", c.descriptor.getTransportDescriptor())
|
||||||
if err = wire.SendMessageString(conn, req); err != nil {
|
if err = wire.SendMessageString(conn, req); err != nil {
|
||||||
conn.Close()
|
conn.Close()
|
||||||
return nil, fmt.Errorf("error connecting to device '%s': %+v", c.descriptor, err)
|
return nil, util.WrapErrf(err, "error connecting to device '%s'", c.descriptor)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = wire.ReadStatusFailureAsError(conn, req); err != nil {
|
if err = wire.ReadStatusFailureAsError(conn, req); err != nil {
|
||||||
|
@ -171,33 +192,16 @@ func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
|
|
||||||
conn, err := c.dialDevice()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error connecting to device for sync: %+v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Switch the connection to sync mode.
|
|
||||||
if err := wire.SendMessageString(conn, "sync:"); err != nil {
|
|
||||||
return nil, fmt.Errorf("error requesting sync mode: %+v", err)
|
|
||||||
}
|
|
||||||
if err := wire.ReadStatusFailureAsError(conn, "sync"); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn.NewSyncConn(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareCommandLine validates the command and argument strings, quotes
|
// prepareCommandLine validates the command and argument strings, quotes
|
||||||
// arguments if required, and joins them into a valid adb command string.
|
// arguments if required, and joins them into a valid adb command string.
|
||||||
func prepareCommandLine(cmd string, args ...string) (string, error) {
|
func prepareCommandLine(cmd string, args ...string) (string, error) {
|
||||||
if isBlank(cmd) {
|
if isBlank(cmd) {
|
||||||
return "", fmt.Errorf("command cannot be empty")
|
return "", util.AssertionErrorf("command cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, arg := range args {
|
for i, arg := range args {
|
||||||
if strings.ContainsRune(arg, '"') {
|
if strings.ContainsRune(arg, '"') {
|
||||||
return "", fmt.Errorf("arg at index %d contains an invalid double quote: %s", i, arg)
|
return "", util.Errorf(util.ParseError, "arg at index %d contains an invalid double quote: %s", i, arg)
|
||||||
}
|
}
|
||||||
if containsWhitespace(arg) {
|
if containsWhitespace(arg) {
|
||||||
args[i] = fmt.Sprintf("\"%s\"", arg)
|
args[i] = fmt.Sprintf("\"%s\"", arg)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
"github.com/zach-klippenstein/goadb/wire"
|
"github.com/zach-klippenstein/goadb/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,12 +53,14 @@ func TestPrepareCommandLineNoArgs(t *testing.T) {
|
||||||
|
|
||||||
func TestPrepareCommandLineEmptyCommand(t *testing.T) {
|
func TestPrepareCommandLineEmptyCommand(t *testing.T) {
|
||||||
_, err := prepareCommandLine("")
|
_, err := prepareCommandLine("")
|
||||||
assert.EqualError(t, err, "command cannot be empty")
|
assert.Equal(t, util.AssertionError, code(err))
|
||||||
|
assert.Equal(t, "command cannot be empty", message(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepareCommandLineBlankCommand(t *testing.T) {
|
func TestPrepareCommandLineBlankCommand(t *testing.T) {
|
||||||
_, err := prepareCommandLine(" ")
|
_, err := prepareCommandLine(" ")
|
||||||
assert.EqualError(t, err, "command cannot be empty")
|
assert.Equal(t, util.AssertionError, code(err))
|
||||||
|
assert.Equal(t, "command cannot be empty", message(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPrepareCommandLineCleanArgs(t *testing.T) {
|
func TestPrepareCommandLineCleanArgs(t *testing.T) {
|
||||||
|
@ -74,5 +77,14 @@ func TestPrepareCommandLineArgWithWhitespaceQuotes(t *testing.T) {
|
||||||
|
|
||||||
func TestPrepareCommandLineArgWithDoubleQuoteFails(t *testing.T) {
|
func TestPrepareCommandLineArgWithDoubleQuoteFails(t *testing.T) {
|
||||||
_, err := prepareCommandLine("cmd", "quoted\"arg")
|
_, err := prepareCommandLine("cmd", "quoted\"arg")
|
||||||
assert.EqualError(t, err, "arg at index 0 contains an invalid double quote: quoted\"arg")
|
assert.Equal(t, util.ParseError, code(err))
|
||||||
|
assert.Equal(t, "arg at index 0 contains an invalid double quote: quoted\"arg", message(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func code(err error) util.ErrCode {
|
||||||
|
return err.(*util.Err).Code
|
||||||
|
}
|
||||||
|
|
||||||
|
func message(err error) string {
|
||||||
|
return err.(*util.Err).Message
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,8 +2,9 @@ package goadb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DeviceInfo struct {
|
type DeviceInfo struct {
|
||||||
|
@ -26,7 +27,7 @@ func (d *DeviceInfo) IsUsb() bool {
|
||||||
|
|
||||||
func newDevice(serial string, attrs map[string]string) (*DeviceInfo, error) {
|
func newDevice(serial string, attrs map[string]string) (*DeviceInfo, error) {
|
||||||
if serial == "" {
|
if serial == "" {
|
||||||
return nil, fmt.Errorf("device serial cannot be blank")
|
return nil, util.AssertionErrorf("device serial cannot be blank")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &DeviceInfo{
|
return &DeviceInfo{
|
||||||
|
@ -49,9 +50,6 @@ func parseDeviceList(list string, lineParseFunc func(string) (*DeviceInfo, error
|
||||||
}
|
}
|
||||||
devices = append(devices, device)
|
devices = append(devices, device)
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return devices, nil
|
return devices, nil
|
||||||
}
|
}
|
||||||
|
@ -59,7 +57,8 @@ func parseDeviceList(list string, lineParseFunc func(string) (*DeviceInfo, error
|
||||||
func parseDeviceShort(line string) (*DeviceInfo, error) {
|
func parseDeviceShort(line string) (*DeviceInfo, error) {
|
||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
if len(fields) != 2 {
|
if len(fields) != 2 {
|
||||||
return nil, fmt.Errorf("malformed device line, expected 2 fields but found %d", len(fields))
|
return nil, util.Errorf(util.ParseError,
|
||||||
|
"malformed device line, expected 2 fields but found %d", len(fields))
|
||||||
}
|
}
|
||||||
|
|
||||||
return newDevice(fields[0], map[string]string{})
|
return newDevice(fields[0], map[string]string{})
|
||||||
|
@ -68,7 +67,8 @@ func parseDeviceShort(line string) (*DeviceInfo, error) {
|
||||||
func parseDeviceLong(line string) (*DeviceInfo, error) {
|
func parseDeviceLong(line string) (*DeviceInfo, error) {
|
||||||
fields := strings.Fields(line)
|
fields := strings.Fields(line)
|
||||||
if len(fields) < 5 {
|
if len(fields) < 5 {
|
||||||
return nil, fmt.Errorf("malformed device line, expected at least 5 fields but found %d", len(fields))
|
return nil, util.Errorf(util.ParseError,
|
||||||
|
"malformed device line, expected at least 5 fields but found %d", len(fields))
|
||||||
}
|
}
|
||||||
|
|
||||||
attrs := parseDeviceAttributes(fields[2:])
|
attrs := parseDeviceAttributes(fields[2:])
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
"github.com/zach-klippenstein/goadb/wire"
|
"github.com/zach-klippenstein/goadb/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -50,9 +51,10 @@ func (d *netDialer) Dial() (*wire.Conn, error) {
|
||||||
host := d.Host
|
host := d.Host
|
||||||
port := d.Port
|
port := d.Port
|
||||||
|
|
||||||
netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port))
|
address := fmt.Sprintf("%s:%d", host, port)
|
||||||
|
netConn, err := net.Dial("tcp", address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.WrapErrorf(err, util.NetworkError, "error dialing %s", address)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn := &wire.Conn{
|
conn := &wire.Conn{
|
||||||
|
@ -68,7 +70,6 @@ func (d *netDialer) Dial() (*wire.Conn, error) {
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(zach): Make this unexported.
|
|
||||||
func roundTripSingleResponse(d Dialer, req string) ([]byte, error) {
|
func roundTripSingleResponse(d Dialer, req string) ([]byte, error) {
|
||||||
conn, err := d.Dial()
|
conn, err := d.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -4,6 +4,7 @@ package goadb
|
||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
"github.com/zach-klippenstein/goadb/wire"
|
"github.com/zach-klippenstein/goadb/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,11 +47,14 @@ func NewHostClient(config ClientConfig) *HostClient {
|
||||||
func (c *HostClient) GetServerVersion() (int, error) {
|
func (c *HostClient) GetServerVersion() (int, error) {
|
||||||
resp, err := roundTripSingleResponse(c.config.Dialer, "host:version")
|
resp, err := roundTripSingleResponse(c.config.Dialer, "host:version")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, wrapClientError(err, c, "GetServerVersion")
|
||||||
}
|
}
|
||||||
|
|
||||||
version, err := strconv.ParseInt(string(resp), 16, 32)
|
version, err := c.parseServerVersion(resp)
|
||||||
return int(version), err
|
if err != nil {
|
||||||
|
return 0, wrapClientError(err, c, "GetServerVersion")
|
||||||
|
}
|
||||||
|
return version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -62,12 +66,12 @@ Corresponds to the command:
|
||||||
func (c *HostClient) KillServer() error {
|
func (c *HostClient) KillServer() error {
|
||||||
conn, err := c.config.Dialer.Dial()
|
conn, err := c.config.Dialer.Dial()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return wrapClientError(err, c, "KillServer")
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
if err = wire.SendMessageString(conn, "host:kill"); err != nil {
|
if err = wire.SendMessageString(conn, "host:kill"); err != nil {
|
||||||
return err
|
return wrapClientError(err, c, "KillServer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -82,12 +86,12 @@ Corresponds to the command:
|
||||||
func (c *HostClient) ListDeviceSerials() ([]string, error) {
|
func (c *HostClient) ListDeviceSerials() ([]string, error) {
|
||||||
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices")
|
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, wrapClientError(err, c, "ListDeviceSerials")
|
||||||
}
|
}
|
||||||
|
|
||||||
devices, err := parseDeviceList(string(resp), parseDeviceShort)
|
devices, err := parseDeviceList(string(resp), parseDeviceShort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, wrapClientError(err, c, "ListDeviceSerials")
|
||||||
}
|
}
|
||||||
|
|
||||||
serials := make([]string, len(devices))
|
serials := make([]string, len(devices))
|
||||||
|
@ -106,8 +110,22 @@ Corresponds to the command:
|
||||||
func (c *HostClient) ListDevices() ([]*DeviceInfo, error) {
|
func (c *HostClient) ListDevices() ([]*DeviceInfo, error) {
|
||||||
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices-l")
|
resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices-l")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, wrapClientError(err, c, "ListDevices")
|
||||||
}
|
}
|
||||||
|
|
||||||
return parseDeviceList(string(resp), parseDeviceLong)
|
devices, err := parseDeviceList(string(resp), parseDeviceLong)
|
||||||
|
if err != nil {
|
||||||
|
return nil, wrapClientError(err, c, "ListDevices")
|
||||||
|
}
|
||||||
|
return devices, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HostClient) parseServerVersion(versionRaw []byte) (int, error) {
|
||||||
|
versionStr := string(versionRaw)
|
||||||
|
version, err := strconv.ParseInt(versionStr, 16, 32)
|
||||||
|
if err != nil {
|
||||||
|
return 0, util.WrapErrorf(err, util.ParseError,
|
||||||
|
"error parsing server version: %s", versionStr)
|
||||||
|
}
|
||||||
|
return int(version), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
"github.com/zach-klippenstein/goadb/wire"
|
"github.com/zach-klippenstein/goadb/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,7 +47,7 @@ func (s *MockServer) ReadStatus() (wire.StatusCode, error) {
|
||||||
|
|
||||||
func (s *MockServer) ReadMessage() ([]byte, error) {
|
func (s *MockServer) ReadMessage() ([]byte, error) {
|
||||||
if s.nextMsgIndex >= len(s.Messages) {
|
if s.nextMsgIndex >= len(s.Messages) {
|
||||||
return nil, io.EOF
|
return nil, util.WrapErrorf(io.EOF, util.NetworkError, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
s.nextMsgIndex++
|
s.nextMsgIndex++
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package goadb
|
package goadb
|
||||||
|
|
||||||
import "os/exec"
|
import (
|
||||||
|
"os/exec"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
StartServer ensures there is a server running.
|
StartServer ensures there is a server running.
|
||||||
|
@ -10,5 +14,6 @@ Currently implemented by just running
|
||||||
*/
|
*/
|
||||||
func StartServer() error {
|
func StartServer() error {
|
||||||
cmd := exec.Command("adb", "start-server")
|
cmd := exec.Command("adb", "start-server")
|
||||||
return cmd.Run()
|
err := cmd.Run()
|
||||||
|
return util.WrapErrorf(err, util.ServerNotAvailable, "error starting server: %s", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,11 +2,11 @@
|
||||||
package goadb
|
package goadb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
"github.com/zach-klippenstein/goadb/wire"
|
"github.com/zach-klippenstein/goadb/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ func stat(conn *wire.SyncConn, path string) (*DirEntry, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if id != "STAT" {
|
if id != "STAT" {
|
||||||
return nil, fmt.Errorf("expected stat ID 'STAT', but got '%s'", id)
|
return nil, util.Errorf(util.AssertionError, "expected stat ID 'STAT', but got '%s'", id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return readStat(conn)
|
return readStat(conn)
|
||||||
|
@ -56,24 +56,24 @@ func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) {
|
||||||
func readStat(s wire.SyncScanner) (entry *DirEntry, err error) {
|
func readStat(s wire.SyncScanner) (entry *DirEntry, err error) {
|
||||||
mode, err := s.ReadFileMode()
|
mode, err := s.ReadFileMode()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("error reading file mode: %v", err)
|
err = util.WrapErrf(err, "error reading file mode: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
size, err := s.ReadInt32()
|
size, err := s.ReadInt32()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("error reading file size: %v", err)
|
err = util.WrapErrf(err, "error reading file size: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
mtime, err := s.ReadTime()
|
mtime, err := s.ReadTime()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("error reading file time: %v", err)
|
err = util.WrapErrf(err, "error reading file time: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// adb doesn't indicate when a file doesn't exist, but will return all zeros.
|
// adb doesn't indicate when a file doesn't exist, but will return all zeros.
|
||||||
// Theoretically this could be an actual file, but that's very unlikely.
|
// Theoretically this could be an actual file, but that's very unlikely.
|
||||||
if mode == os.FileMode(0) && size == 0 && mtime == zeroTime {
|
if mode == os.FileMode(0) && size == 0 && mtime == zeroTime {
|
||||||
return nil, os.ErrNotExist
|
return nil, util.Errorf(util.FileNoExistError, "file doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
entry = &DirEntry{
|
entry = &DirEntry{
|
||||||
|
|
|
@ -8,6 +8,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
"github.com/zach-klippenstein/goadb/wire"
|
"github.com/zach-klippenstein/goadb/wire"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -26,6 +28,7 @@ func TestStatValid(t *testing.T) {
|
||||||
|
|
||||||
entry, err := stat(conn, "/thing")
|
entry, err := stat(conn, "/thing")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
require.NotNil(t, entry)
|
||||||
assert.Equal(t, mode, entry.Mode, "expected os.FileMode %s, got %s", mode, entry.Mode)
|
assert.Equal(t, mode, entry.Mode, "expected os.FileMode %s, got %s", mode, entry.Mode)
|
||||||
assert.Equal(t, int32(4), entry.Size)
|
assert.Equal(t, int32(4), entry.Size)
|
||||||
assert.Equal(t, someTime, entry.ModifiedAt)
|
assert.Equal(t, someTime, entry.ModifiedAt)
|
||||||
|
@ -54,5 +57,5 @@ func TestStatNoExist(t *testing.T) {
|
||||||
|
|
||||||
entry, err := stat(conn, "/")
|
entry, err := stat(conn, "/")
|
||||||
assert.Nil(t, entry)
|
assert.Nil(t, entry)
|
||||||
assert.Equal(t, os.ErrNotExist, err)
|
assert.Equal(t, util.FileNoExistError, err.(*util.Err).Code)
|
||||||
}
|
}
|
||||||
|
|
24
util.go
24
util.go
|
@ -1,8 +1,12 @@
|
||||||
package goadb
|
package goadb
|
||||||
|
|
||||||
import "strings"
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -16,3 +20,21 @@ func containsWhitespace(str string) bool {
|
||||||
func isBlank(str string) bool {
|
func isBlank(str string) bool {
|
||||||
return whitespaceRegex.MatchString(str)
|
return whitespaceRegex.MatchString(str)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func wrapClientError(err error, client interface{}, operation string, args ...interface{}) error {
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if _, ok := err.(*util.Err); !ok {
|
||||||
|
panic("err is not a *util.Err")
|
||||||
|
}
|
||||||
|
|
||||||
|
clientType := reflect.TypeOf(client)
|
||||||
|
|
||||||
|
return &util.Err{
|
||||||
|
Code: err.(*util.Err).Code,
|
||||||
|
Cause: err,
|
||||||
|
Message: fmt.Sprintf("error performing %s on %s", fmt.Sprintf(operation, args...), clientType),
|
||||||
|
Details: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
16
util/errcode_string.go
Normal file
16
util/errcode_string.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
// generated by stringer -type=ErrCode; DO NOT EDIT
|
||||||
|
|
||||||
|
package util
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
const _ErrCode_name = "AssertionErrorParseErrorServerNotAvailableNetworkErrorAdbErrorDeviceNotFoundFileNoExistError"
|
||||||
|
|
||||||
|
var _ErrCode_index = [...]uint8{0, 14, 24, 42, 54, 62, 76, 92}
|
||||||
|
|
||||||
|
func (i ErrCode) String() string {
|
||||||
|
if i+1 >= ErrCode(len(_ErrCode_index)) {
|
||||||
|
return fmt.Sprintf("ErrCode(%d)", i)
|
||||||
|
}
|
||||||
|
return _ErrCode_name[_ErrCode_index[i]:_ErrCode_index[i+1]]
|
||||||
|
}
|
96
util/error.go
Normal file
96
util/error.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
package util
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// Err is the implementation of error that all goadb functions return.
|
||||||
|
type Err struct {
|
||||||
|
// Code is the high-level "type" of error.
|
||||||
|
Code ErrCode
|
||||||
|
// Message is a human-readable description of the error.
|
||||||
|
Message string
|
||||||
|
// Details is optional, and can be used to associate any auxiliary data with an error.
|
||||||
|
Details interface{}
|
||||||
|
// Cause is optional, and points to the more specific error that caused this one.
|
||||||
|
Cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ error = &Err{}
|
||||||
|
|
||||||
|
//go:generate stringer -type=ErrCode
|
||||||
|
type ErrCode byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
AssertionError ErrCode = iota
|
||||||
|
ParseError ErrCode = iota
|
||||||
|
// The server was not available on the request port and could not be started.
|
||||||
|
ServerNotAvailable ErrCode = iota
|
||||||
|
// General network error communicating with the server.
|
||||||
|
NetworkError ErrCode = iota
|
||||||
|
// The server returned an error message, but we couldn't parse it.
|
||||||
|
AdbError ErrCode = iota
|
||||||
|
// The server returned a "device not found" error.
|
||||||
|
DeviceNotFound ErrCode = iota
|
||||||
|
// Tried to perform an operation on a path that doesn't exist on the device.
|
||||||
|
FileNoExistError ErrCode = iota
|
||||||
|
)
|
||||||
|
|
||||||
|
func Errorf(code ErrCode, format string, args ...interface{}) error {
|
||||||
|
return &Err{
|
||||||
|
Code: code,
|
||||||
|
Message: fmt.Sprintf(format, args...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
WrapErrf returns an *Err that wraps another *Err and has the same ErrCode.
|
||||||
|
Panics if cause is not an *Err.
|
||||||
|
|
||||||
|
To wrap generic errors, use WrapErrorf.
|
||||||
|
*/
|
||||||
|
func WrapErrf(cause error, format string, args ...interface{}) error {
|
||||||
|
if cause == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cause.(*Err)
|
||||||
|
return &Err{
|
||||||
|
Code: err.Code,
|
||||||
|
Message: fmt.Sprintf(format, args...),
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
WrapErrorf returns an *Err that wraps another arbitrary error with an ErrCode and a message.
|
||||||
|
|
||||||
|
If cause is nil, returns nil, so you can use it like
|
||||||
|
return util.WrapErrorf(DoSomethingDangerous(), util.NetworkError, "well that didn't work")
|
||||||
|
|
||||||
|
If cause is known to be of type *Err, use WrapErrf.
|
||||||
|
*/
|
||||||
|
func WrapErrorf(cause error, code ErrCode, format string, args ...interface{}) error {
|
||||||
|
if cause == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Err{
|
||||||
|
Code: code,
|
||||||
|
Message: fmt.Sprintf(format, args...),
|
||||||
|
Cause: cause,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AssertionErrorf(format string, args ...interface{}) error {
|
||||||
|
return &Err{
|
||||||
|
Code: AssertionError,
|
||||||
|
Message: fmt.Sprintf(format, args...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err *Err) Error() string {
|
||||||
|
msg := fmt.Sprintf("%s: %s", err.Code, err.Message)
|
||||||
|
if err.Details != nil {
|
||||||
|
msg = fmt.Sprintf("%s (%+v)", msg, err.Details)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
|
@ -1,24 +0,0 @@
|
||||||
package wire
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AdbServerError struct {
|
|
||||||
Request string
|
|
||||||
ServerMsg string
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ error = &AdbServerError{}
|
|
||||||
|
|
||||||
func (e *AdbServerError) Error() string {
|
|
||||||
if e.Request == "" {
|
|
||||||
return fmt.Sprintf("server error: %s", e.ServerMsg)
|
|
||||||
} else {
|
|
||||||
return fmt.Sprintf("server error for %s request: %s", e.Request, e.ServerMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func incompleteMessage(description string, actual int, expected int) error {
|
|
||||||
return fmt.Errorf("incomplete %s: read %d bytes, expecting %d", description, actual, expected)
|
|
||||||
}
|
|
20
wire/conn.go
20
wire/conn.go
|
@ -1,5 +1,7 @@
|
||||||
package wire
|
package wire
|
||||||
|
|
||||||
|
import "github.com/zach-klippenstein/goadb/util"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// The official implementation of adb imposes an undocumented 255-byte limit
|
// The official implementation of adb imposes an undocumented 255-byte limit
|
||||||
// on messages.
|
// on messages.
|
||||||
|
@ -57,8 +59,20 @@ func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) Close() error {
|
func (conn *Conn) Close() error {
|
||||||
if err := conn.Sender.Close(); err != nil {
|
errs := struct {
|
||||||
return err
|
SenderErr error
|
||||||
|
ScannerErr error
|
||||||
|
}{
|
||||||
|
SenderErr: conn.Sender.Close(),
|
||||||
|
ScannerErr: conn.Scanner.Close(),
|
||||||
}
|
}
|
||||||
return conn.Scanner.Close()
|
|
||||||
|
if errs.ScannerErr != nil || errs.SenderErr != nil {
|
||||||
|
return &util.Err{
|
||||||
|
Code: util.NetworkError,
|
||||||
|
Message: "error closing connection",
|
||||||
|
Details: errs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
package wire
|
package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StatusCodes are returned by the server. If the code indicates failure, the
|
// StatusCodes are returned by the server. If the code indicates failure, the
|
||||||
|
@ -54,33 +55,41 @@ func ReadMessageString(s Scanner) (string, error) {
|
||||||
func (s *realScanner) ReadStatus() (StatusCode, error) {
|
func (s *realScanner) ReadStatus() (StatusCode, error) {
|
||||||
status := make([]byte, 4)
|
status := make([]byte, 4)
|
||||||
n, err := io.ReadFull(s.reader, status)
|
n, err := io.ReadFull(s.reader, status)
|
||||||
|
|
||||||
if err != nil && err != io.ErrUnexpectedEOF {
|
if err != nil && err != io.ErrUnexpectedEOF {
|
||||||
return "", err
|
return "", util.WrapErrorf(err, util.NetworkError, "error reading status")
|
||||||
} else if err == io.ErrUnexpectedEOF {
|
} else if err == io.ErrUnexpectedEOF {
|
||||||
return StatusCode(status), incompleteMessage("status", n, 4)
|
return StatusCode(status), errIncompleteMessage("status", n, 4)
|
||||||
}
|
}
|
||||||
|
|
||||||
return StatusCode(status), nil
|
return StatusCode(status), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realScanner) ReadMessage() ([]byte, error) {
|
func (s *realScanner) ReadMessage() ([]byte, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
length, err := s.readLength()
|
length, err := s.readLength()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.WrapErrorf(err, util.NetworkError, "error reading message length")
|
||||||
}
|
}
|
||||||
|
|
||||||
data := make([]byte, length)
|
data := make([]byte, length)
|
||||||
n, err := io.ReadFull(s.reader, data)
|
n, err := io.ReadFull(s.reader, data)
|
||||||
|
|
||||||
if err != nil && err != io.ErrUnexpectedEOF {
|
if err != nil && err != io.ErrUnexpectedEOF {
|
||||||
return data, fmt.Errorf("error reading message data: %v", err)
|
return data, util.WrapErrorf(err, util.NetworkError, "error reading message data")
|
||||||
} else if err == io.ErrUnexpectedEOF {
|
} else if err == io.ErrUnexpectedEOF {
|
||||||
return data, incompleteMessage("message data", n, length)
|
return data, errIncompleteMessage("message data", n, length)
|
||||||
}
|
}
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realScanner) ReadUntilEof() ([]byte, error) {
|
func (s *realScanner) ReadUntilEof() ([]byte, error) {
|
||||||
return ioutil.ReadAll(s.reader)
|
data, err := ioutil.ReadAll(s.reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, util.WrapErrorf(err, util.NetworkError, "error reading until EOF")
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realScanner) NewSyncScanner() SyncScanner {
|
func (s *realScanner) NewSyncScanner() SyncScanner {
|
||||||
|
@ -88,21 +97,21 @@ func (s *realScanner) NewSyncScanner() SyncScanner {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realScanner) Close() error {
|
func (s *realScanner) Close() error {
|
||||||
return s.reader.Close()
|
return util.WrapErrorf(s.reader.Close(), util.NetworkError, "error closing scanner")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realScanner) readLength() (int, error) {
|
func (s *realScanner) readLength() (int, error) {
|
||||||
lengthHex := make([]byte, 4)
|
lengthHex := make([]byte, 4)
|
||||||
n, err := io.ReadFull(s.reader, lengthHex)
|
n, err := io.ReadFull(s.reader, lengthHex)
|
||||||
if err != nil && err != io.ErrUnexpectedEOF {
|
if err != nil && err != io.ErrUnexpectedEOF {
|
||||||
return 0, err
|
return 0, util.WrapErrorf(err, util.NetworkError, "error reading length")
|
||||||
} else if err == io.ErrUnexpectedEOF {
|
} else if err == io.ErrUnexpectedEOF {
|
||||||
return 0, incompleteMessage("length", n, 4)
|
return 0, errIncompleteMessage("length", n, 4)
|
||||||
}
|
}
|
||||||
|
|
||||||
length, err := strconv.ParseInt(string(lengthHex), 16, 64)
|
length, err := strconv.ParseInt(string(lengthHex), 16, 64)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("invalid hex length: %v", err)
|
return 0, util.WrapErrorf(err, util.NetworkError, "could not parse hex length %v", lengthHex)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clip the length to 255, as per the Google implementation.
|
// Clip the length to 255, as per the Google implementation.
|
||||||
|
|
|
@ -3,6 +3,8 @@ package wire
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sender sends messages to the server.
|
// Sender sends messages to the server.
|
||||||
|
@ -28,7 +30,7 @@ func SendMessageString(s Sender, msg string) error {
|
||||||
|
|
||||||
func (s *realSender) SendMessage(msg []byte) error {
|
func (s *realSender) SendMessage(msg []byte) error {
|
||||||
if len(msg) > MaxMessageLength {
|
if len(msg) > MaxMessageLength {
|
||||||
return fmt.Errorf("message length exceeds maximum: %d", len(msg))
|
return util.AssertionErrorf("message length exceeds maximum: %d", len(msg))
|
||||||
}
|
}
|
||||||
|
|
||||||
lengthAndMsg := fmt.Sprintf("%04x%s", len(msg), msg)
|
lengthAndMsg := fmt.Sprintf("%04x%s", len(msg), msg)
|
||||||
|
@ -40,7 +42,7 @@ func (s *realSender) NewSyncSender() SyncSender {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realSender) Close() error {
|
func (s *realSender) Close() error {
|
||||||
return s.writer.Close()
|
return util.WrapErrorf(s.writer.Close(), util.NetworkError, "error closing sender")
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Sender = &realSender{}
|
var _ Sender = &realSender{}
|
||||||
|
|
|
@ -2,10 +2,11 @@ package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SyncScanner interface {
|
type SyncScanner interface {
|
||||||
|
@ -38,10 +39,10 @@ func NewSyncScanner(r io.Reader) SyncScanner {
|
||||||
func RequireOctetString(s SyncScanner, expected string) error {
|
func RequireOctetString(s SyncScanner, expected string) error {
|
||||||
actual, err := s.ReadOctetString()
|
actual, err := s.ReadOctetString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("expected to read '%s', got err: %v", expected, err)
|
return util.WrapErrorf(err, util.NetworkError, "expected to read '%s'", expected)
|
||||||
}
|
}
|
||||||
if actual != expected {
|
if actual != expected {
|
||||||
return fmt.Errorf("expected to read '%s', got '%s'", expected, actual)
|
return util.AssertionErrorf("expected to read '%s', got '%s'", expected, actual)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -49,10 +50,11 @@ func RequireOctetString(s SyncScanner, expected string) error {
|
||||||
func (s *realSyncScanner) ReadOctetString() (string, error) {
|
func (s *realSyncScanner) ReadOctetString() (string, error) {
|
||||||
octet := make([]byte, 4)
|
octet := make([]byte, 4)
|
||||||
n, err := io.ReadFull(s.Reader, octet)
|
n, err := io.ReadFull(s.Reader, octet)
|
||||||
|
|
||||||
if err != nil && err != io.ErrUnexpectedEOF {
|
if err != nil && err != io.ErrUnexpectedEOF {
|
||||||
return "", err
|
return "", util.WrapErrorf(err, util.NetworkError, "error reading octet string from sync scanner")
|
||||||
} else if err == io.ErrUnexpectedEOF {
|
} else if err == io.ErrUnexpectedEOF {
|
||||||
return "", incompleteMessage("octet", n, 4)
|
return "", errIncompleteMessage("octet", n, 4)
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(octet), nil
|
return string(octet), nil
|
||||||
|
@ -60,20 +62,21 @@ func (s *realSyncScanner) ReadOctetString() (string, error) {
|
||||||
func (s *realSyncScanner) ReadInt32() (int32, error) {
|
func (s *realSyncScanner) ReadInt32() (int32, error) {
|
||||||
var value int32
|
var value int32
|
||||||
err := binary.Read(s.Reader, binary.LittleEndian, &value)
|
err := binary.Read(s.Reader, binary.LittleEndian, &value)
|
||||||
return value, err
|
return value, util.WrapErrorf(err, util.NetworkError, "error reading int from sync scanner")
|
||||||
}
|
}
|
||||||
func (s *realSyncScanner) ReadFileMode() (filemode os.FileMode, err error) {
|
func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) {
|
||||||
var value uint32
|
var value uint32
|
||||||
err = binary.Read(s.Reader, binary.LittleEndian, &value)
|
err := binary.Read(s.Reader, binary.LittleEndian, &value)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
filemode = ParseFileModeFromAdb(value)
|
return 0, util.WrapErrorf(err, util.NetworkError, "error reading filemode from sync scanner")
|
||||||
}
|
}
|
||||||
return
|
return ParseFileModeFromAdb(value), nil
|
||||||
|
|
||||||
}
|
}
|
||||||
func (s *realSyncScanner) ReadTime() (time.Time, error) {
|
func (s *realSyncScanner) ReadTime() (time.Time, error) {
|
||||||
seconds, err := s.ReadInt32()
|
seconds, err := s.ReadInt32()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return time.Time{}, err
|
return time.Time{}, util.WrapErrorf(err, util.NetworkError, "error reading time from sync scanner")
|
||||||
}
|
}
|
||||||
|
|
||||||
return time.Unix(int64(seconds), 0).UTC(), nil
|
return time.Unix(int64(seconds), 0).UTC(), nil
|
||||||
|
@ -82,15 +85,15 @@ func (s *realSyncScanner) ReadTime() (time.Time, error) {
|
||||||
func (s *realSyncScanner) ReadString() (string, error) {
|
func (s *realSyncScanner) ReadString() (string, error) {
|
||||||
length, err := s.ReadInt32()
|
length, err := s.ReadInt32()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", util.WrapErrorf(err, util.NetworkError, "error reading length from sync scanner")
|
||||||
}
|
}
|
||||||
|
|
||||||
bytes := make([]byte, length)
|
bytes := make([]byte, length)
|
||||||
n, err := io.ReadFull(s.Reader, bytes)
|
n, rawErr := io.ReadFull(s.Reader, bytes)
|
||||||
if err != nil && err != io.ErrUnexpectedEOF {
|
if rawErr != nil && rawErr != io.ErrUnexpectedEOF {
|
||||||
return "", err
|
return "", util.WrapErrorf(rawErr, util.NetworkError, "error reading string from sync scanner")
|
||||||
} else if err == io.ErrUnexpectedEOF {
|
} else if rawErr == io.ErrUnexpectedEOF {
|
||||||
return "", incompleteMessage("bytes", n, int(length))
|
return "", errIncompleteMessage("bytes", n, int(length))
|
||||||
}
|
}
|
||||||
|
|
||||||
return string(bytes), nil
|
return string(bytes), nil
|
||||||
|
@ -98,7 +101,7 @@ func (s *realSyncScanner) ReadString() (string, error) {
|
||||||
func (s *realSyncScanner) ReadBytes() (io.Reader, error) {
|
func (s *realSyncScanner) ReadBytes() (io.Reader, error) {
|
||||||
length, err := s.ReadInt32()
|
length, err := s.ReadInt32()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, util.WrapErrorf(err, util.NetworkError, "error reading bytes from sync scanner")
|
||||||
}
|
}
|
||||||
|
|
||||||
return io.LimitReader(s.Reader, int64(length)), nil
|
return io.LimitReader(s.Reader, int64(length)), nil
|
||||||
|
@ -106,7 +109,7 @@ func (s *realSyncScanner) ReadBytes() (io.Reader, error) {
|
||||||
|
|
||||||
func (s *realSyncScanner) Close() error {
|
func (s *realSyncScanner) Close() error {
|
||||||
if closer, ok := s.Reader.(io.Closer); ok {
|
if closer, ok := s.Reader.(io.Closer); ok {
|
||||||
return closer.Close()
|
return util.WrapErrorf(closer.Close(), util.NetworkError, "error closing sync scanner")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,10 +2,11 @@ package wire
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type SyncSender interface {
|
type SyncSender interface {
|
||||||
|
@ -29,21 +30,28 @@ func NewSyncSender(w io.Writer) SyncSender {
|
||||||
|
|
||||||
func (s *realSyncSender) SendOctetString(str string) error {
|
func (s *realSyncSender) SendOctetString(str string) error {
|
||||||
if len(str) != 4 {
|
if len(str) != 4 {
|
||||||
return fmt.Errorf("octet string must be exactly 4 bytes: '%s'", str)
|
return util.AssertionErrorf("octet string must be exactly 4 bytes: '%s'", str)
|
||||||
}
|
}
|
||||||
return writeFully(s.Writer, []byte(str))
|
|
||||||
|
wrappedErr := util.WrapErrorf(writeFully(s.Writer, []byte(str)),
|
||||||
|
util.NetworkError, "error sending octet string on sync sender")
|
||||||
|
|
||||||
|
return wrappedErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realSyncSender) SendInt32(val int32) error {
|
func (s *realSyncSender) SendInt32(val int32) error {
|
||||||
return binary.Write(s.Writer, binary.LittleEndian, val)
|
return util.WrapErrorf(binary.Write(s.Writer, binary.LittleEndian, val),
|
||||||
|
util.NetworkError, "error sending int on sync sender")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realSyncSender) SendFileMode(mode os.FileMode) error {
|
func (s *realSyncSender) SendFileMode(mode os.FileMode) error {
|
||||||
return binary.Write(s.Writer, binary.LittleEndian, mode)
|
return util.WrapErrorf(binary.Write(s.Writer, binary.LittleEndian, mode),
|
||||||
|
util.NetworkError, "error sending filemode on sync sender")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realSyncSender) SendTime(t time.Time) error {
|
func (s *realSyncSender) SendTime(t time.Time) error {
|
||||||
return s.SendInt32(int32(t.Unix()))
|
return util.WrapErrorf(s.SendInt32(int32(t.Unix())),
|
||||||
|
util.NetworkError, "error sending time on sync sender")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *realSyncSender) SendString(str string) error {
|
func (s *realSyncSender) SendString(str string) error {
|
||||||
|
@ -51,11 +59,12 @@ func (s *realSyncSender) SendString(str string) error {
|
||||||
if length > MaxChunkSize {
|
if length > MaxChunkSize {
|
||||||
// This limit might not apply to filenames, but it's big enough
|
// This limit might not apply to filenames, but it's big enough
|
||||||
// that I don't think it will be a problem.
|
// that I don't think it will be a problem.
|
||||||
return fmt.Errorf("str must be <= %d in length", MaxChunkSize)
|
return util.AssertionErrorf("str must be <= %d in length", MaxChunkSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.SendInt32(int32(length)); err != nil {
|
if err := s.SendInt32(int32(length)); err != nil {
|
||||||
return err
|
return util.WrapErrorf(err, util.NetworkError, "error sending string length on sync sender")
|
||||||
}
|
}
|
||||||
return writeFully(s.Writer, []byte(str))
|
return util.WrapErrorf(writeFully(s.Writer, []byte(str)),
|
||||||
|
util.NetworkError, "error sending string on sync sender")
|
||||||
}
|
}
|
||||||
|
|
64
wire/util.go
64
wire/util.go
|
@ -3,41 +3,85 @@ package wire
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/zach-klippenstein/goadb/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrorResponseDetails is an error message returned by the server for a particular request.
|
||||||
|
type ErrorResponseDetails struct {
|
||||||
|
Request string
|
||||||
|
ServerMsg string
|
||||||
|
}
|
||||||
|
|
||||||
// Reads the status, and if failure, reads the message and returns it as an error.
|
// Reads the status, and if failure, reads the message and returns it as an error.
|
||||||
// If the status is success, doesn't read the message.
|
// If the status is success, doesn't read the message.
|
||||||
// req is just used to populate the AdbError, and can be nil.
|
// req is just used to populate the AdbError, and can be nil.
|
||||||
func ReadStatusFailureAsError(s Scanner, req string) error {
|
func ReadStatusFailureAsError(s Scanner, req string) error {
|
||||||
status, err := s.ReadStatus()
|
status, err := s.ReadStatus()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error reading status for %s: %+v", req, err)
|
return util.WrapErrorf(err, util.NetworkError, "error reading status for %s", req)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !status.IsSuccess() {
|
if !status.IsSuccess() {
|
||||||
msg, err := s.ReadMessage()
|
msg, err := s.ReadMessage()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("server returned error for %s, but couldn't read the error message: %+v", err)
|
return util.WrapErrorf(err, util.NetworkError,
|
||||||
|
"server returned error for %s, but couldn't read the error message", req)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AdbServerError{
|
return adbServerError(req, string(msg))
|
||||||
Request: req,
|
|
||||||
ServerMsg: string(msg),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func adbServerError(request string, serverMsg string) error {
|
||||||
|
var msg string
|
||||||
|
if request == "" {
|
||||||
|
msg = fmt.Sprintf("server error: %s", serverMsg)
|
||||||
|
} else {
|
||||||
|
msg = fmt.Sprintf("server error for %s request: %s", request, serverMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
errCode := util.AdbError
|
||||||
|
if serverMsg == "device not found" {
|
||||||
|
errCode = util.DeviceNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return &util.Err{
|
||||||
|
Code: errCode,
|
||||||
|
Message: msg,
|
||||||
|
Details: ErrorResponseDetails{
|
||||||
|
Request: request,
|
||||||
|
ServerMsg: serverMsg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func errIncompleteMessage(description string, actual int, expected int) error {
|
||||||
|
return &util.Err{
|
||||||
|
Code: util.NetworkError,
|
||||||
|
Message: fmt.Sprintf("incomplete %s: read %d bytes, expecting %d", description, actual, expected),
|
||||||
|
Details: struct {
|
||||||
|
ActualReadBytes int
|
||||||
|
ExpectedBytes int
|
||||||
|
}{
|
||||||
|
ActualReadBytes: actual,
|
||||||
|
ExpectedBytes: expected,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// writeFully writes all of data to w.
|
// writeFully writes all of data to w.
|
||||||
// Inverse of io.ReadFully().
|
// Inverse of io.ReadFully().
|
||||||
func writeFully(w io.Writer, data []byte) error {
|
func writeFully(w io.Writer, data []byte) error {
|
||||||
for len(data) > 0 {
|
offset := 0
|
||||||
n, err := w.Write(data)
|
for offset < len(data) {
|
||||||
|
n, err := w.Write(data[offset:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return util.WrapErrorf(err, util.NetworkError, "error writing %d bytes at offset %d", len(data), offset)
|
||||||
}
|
}
|
||||||
data = data[n:]
|
offset += n
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue