Started host and local service clients.

Milestone: the demo app prints /proc/loadavg from the device.
This commit is contained in:
Zach Klippenstein 2015-04-12 13:34:20 -07:00
parent f55ab62f4c
commit 4b9891533a
28 changed files with 1470 additions and 156 deletions

View file

@ -1,15 +1,22 @@
// An app demonstrating most of the library's features.
package main
import (
"flag"
"fmt"
"io/ioutil"
"log"
adb "github.com/zach-klippenstein/goadb"
"github.com/zach-klippenstein/goadb/wire"
)
var port = flag.Int("p", wire.AdbPort, "")
func main() {
client := &adb.HostClient{wire.Dial}
flag.Parse()
client := adb.NewHostClientDialer(wire.NewDialer("", *port))
fmt.Println("Starting server…")
client.StartServer()
@ -28,6 +35,84 @@ func main() {
fmt.Printf("\t%+v\n", *device)
}
fmt.Println("Killing server…")
client.KillServer()
PrintDeviceInfoAndError(client.GetAnyDevice())
PrintDeviceInfoAndError(client.GetLocalDevice())
PrintDeviceInfoAndError(client.GetUsbDevice())
serials, err := client.ListDeviceSerials()
if err != nil {
log.Fatal(err)
}
for _, serial := range serials {
PrintDeviceInfoAndError(client.GetDeviceWithSerial(serial))
}
//fmt.Println("Killing server…")
//client.KillServer()
}
func PrintDeviceInfoAndError(device *adb.DeviceClient) {
if err := PrintDeviceInfo(device); err != nil {
log.Println(err)
}
}
func PrintDeviceInfo(device *adb.DeviceClient) error {
serialNo, err := device.GetSerial()
if err != nil {
return err
}
devPath, err := device.GetDevicePath()
if err != nil {
return err
}
state, err := device.GetState()
if err != nil {
return err
}
fmt.Println(device)
fmt.Printf("\tserial no: %s\n", serialNo)
fmt.Printf("\tdevPath: %s\n", devPath)
fmt.Printf("\tstate: %s\n", state)
cmdOutput, err := device.RunCommand("pwd")
if err != nil {
fmt.Println("\terror running command:", err)
}
fmt.Printf("\tcmd output: %s\n", cmdOutput)
stat, err := device.Stat("/sdcard")
if err != nil {
fmt.Println("\terror stating /sdcard:", err)
}
fmt.Printf("\tstat \"/sdcard\": %+v\n", stat)
fmt.Println("\tfiles in \"/\":")
entries, err := device.ListDirEntries("/")
if err != nil {
fmt.Println("\terror listing files:", err)
} else {
for entries.Next() {
fmt.Printf("\t%+v\n", *entries.Entry())
}
if entries.Err() != nil {
fmt.Println("\terror listing files:", err)
}
}
fmt.Print("\tload avg: ")
loadavgReader, err := device.OpenRead("/proc/loadavg")
if err != nil {
fmt.Println("\terror opening file:", err)
} else {
loadAvg, err := ioutil.ReadAll(loadavgReader)
if err != nil {
fmt.Println("\terror reading file:", err)
} else {
fmt.Println(string(loadAvg))
}
}
return nil
}

View file

@ -1,3 +1,4 @@
// A simple tool for sending raw messages to an adb server.
package main
import (
@ -47,7 +48,7 @@ func readLine() string {
}
func doCommand(cmd string) error {
conn, err := wire.DialPort(*port)
conn, err := wire.NewDialer("", *port).Dial()
if err != nil {
log.Fatal(err)
}

208
device_client.go Normal file
View file

@ -0,0 +1,208 @@
package goadb
import (
"fmt"
"io"
"strings"
"github.com/zach-klippenstein/goadb/wire"
)
/*
DeviceClient communicates with a specific Android device.
*/
type DeviceClient struct {
dialer nilSafeDialer
descriptor *DeviceDescriptor
}
func (c *DeviceClient) String() string {
return c.descriptor.String()
}
// get-product is documented, but not implemented in the server.
// TODO(z): Make getProduct exported if get-product is ever implemented in adb.
func (c *DeviceClient) getProduct() (string, error) {
return c.getAttribute("get-product")
}
func (c *DeviceClient) GetSerial() (string, error) {
return c.getAttribute("get-serialno")
}
func (c *DeviceClient) GetDevicePath() (string, error) {
return c.getAttribute("get-devpath")
}
func (c *DeviceClient) GetState() (string, error) {
return c.getAttribute("get-state")
}
/*
RunCommand runs the specified commands on a shell on the device.
From the Android docs:
Run 'command arg1 arg2 ...' in a shell on the device, and return
its output and error streams. Note that arguments must be separated
by spaces. If an argument contains a space, it must be quoted with
double-quotes. Arguments cannot contain double quotes or things
will go very wrong.
Note that this is the non-interactive version of "adb shell"
Source: https://android.googlesource.com/platform/system/core/+/master/adb/SERVICES.TXT
This method quotes the arguments for you, and will return an error if any of them
contain double quotes.
*/
func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) {
cmd, err := prepareCommandLine(cmd, args...)
if err != nil {
return "", err
}
conn, err := c.dialDevice()
if err != nil {
return "", err
}
defer conn.Close()
req := fmt.Sprintf("shell:%s", cmd)
// Shell responses are special, they don't include a length header.
// We read until the stream is closed.
// So, we can't use conn.RoundTripSingleResponse.
if err = conn.SendMessage([]byte(req)); err != nil {
return "", err
}
if err = wire.ReadStatusFailureAsError(conn, []byte(req)); err != nil {
return "", err
}
resp, err := conn.ReadUntilEof()
if err != nil {
return "", err
}
return string(resp), nil
}
/*
Remount, from the docs,
Ask adbd to remount the device's filesystem in read-write mode,
instead of read-only. This is usually necessary before performing
an "adb sync" or "adb push" request.
This request may not succeed on certain builds which do not allow
that.
Source: https://android.googlesource.com/platform/system/core/+/master/adb/SERVICES.TXT
*/
func (c *DeviceClient) Remount() (string, error) {
conn, err := c.dialDevice()
if err != nil {
return "", err
}
defer conn.Close()
resp, err := conn.RoundTripSingleResponse([]byte("remount"))
return string(resp), err
}
func (c *DeviceClient) ListDirEntries(path string) (*DirEntries, error) {
conn, err := c.getSyncConn()
if err != nil {
return nil, err
}
return listDirEntries(conn, path)
}
func (c *DeviceClient) Stat(path string) (*DirEntry, error) {
conn, err := c.getSyncConn()
if err != nil {
return nil, err
}
return stat(conn, path)
}
func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) {
conn, err := c.getSyncConn()
if err != nil {
return nil, err
}
return receiveFile(conn, path)
}
// getAttribute returns the first message returned by the server by running
// <host-prefix>:<attr>, where host-prefix is determined from the DeviceDescriptor.
func (c *DeviceClient) getAttribute(attr string) (string, error) {
resp, err := wire.RoundTripSingleResponse(c.dialer,
fmt.Sprintf("%s:%s", c.descriptor.getHostPrefix(), attr))
if err != nil {
return "", err
}
return string(resp), nil
}
// dialDevice switches the connection to communicate directly with the device
// by requesting the transport defined by the DeviceDescriptor.
func (c *DeviceClient) dialDevice() (*wire.Conn, error) {
conn, err := c.dialer.Dial()
if err != nil {
return nil, err
}
req := fmt.Sprintf("host:%s", c.descriptor.getTransportDescriptor())
if err = wire.SendMessageString(conn, req); err != nil {
conn.Close()
return nil, err
}
if err = wire.ReadStatusFailureAsError(conn, []byte(req)); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) {
conn, err := c.dialDevice()
if err != nil {
return nil, err
}
// Switch the connection to sync mode.
if err := wire.SendMessageString(conn, "sync:"); err != nil {
return nil, err
}
if err := wire.ReadStatusFailureAsError(conn, []byte("sync")); err != nil {
return nil, err
}
return conn.NewSyncConn(), nil
}
// prepareCommandLine validates the command and argument strings, quotes
// arguments if required, and joins them into a valid adb command string.
func prepareCommandLine(cmd string, args ...string) (string, error) {
if isBlank(cmd) {
return "", fmt.Errorf("command cannot be empty")
}
for i, arg := range args {
if strings.ContainsRune(arg, '"') {
return "", fmt.Errorf("arg at index %d contains an invalid double quote: %s", i, arg)
}
if containsWhitespace(arg) {
args[i] = fmt.Sprintf("\"%s\"", arg)
}
}
// Prepend the comand to the args array.
if len(args) > 0 {
cmd = fmt.Sprintf("%s %s", cmd, strings.Join(args, " "))
}
return cmd, nil
}

68
device_client_test.go Normal file
View file

@ -0,0 +1,68 @@
package goadb
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/wire"
)
func TestGetAttribute(t *testing.T) {
s := &MockServer{
Status: wire.StatusSuccess,
Messages: []string{"value"},
}
client := &DeviceClient{nilSafeDialer{s}, deviceWithSerial("serial")}
v, err := client.getAttribute("attr")
assert.Equal(t, "host-serial:serial:attr", s.Requests[0])
assert.NoError(t, err)
assert.Equal(t, "value", v)
}
func TestRunCommandNoArgs(t *testing.T) {
s := &MockServer{
Status: wire.StatusSuccess,
Messages: []string{"output"},
}
client := &DeviceClient{nilSafeDialer{s}, anyDevice()}
v, err := client.RunCommand("cmd")
assert.Equal(t, "host:transport-any", s.Requests[0])
assert.Equal(t, "shell:cmd", s.Requests[1])
assert.NoError(t, err)
assert.Equal(t, "output", v)
}
func TestPrepareCommandLineNoArgs(t *testing.T) {
result, err := prepareCommandLine("cmd")
assert.NoError(t, err)
assert.Equal(t, "cmd", result)
}
func TestPrepareCommandLineEmptyCommand(t *testing.T) {
_, err := prepareCommandLine("")
assert.EqualError(t, err, "command cannot be empty")
}
func TestPrepareCommandLineBlankCommand(t *testing.T) {
_, err := prepareCommandLine(" ")
assert.EqualError(t, err, "command cannot be empty")
}
func TestPrepareCommandLineCleanArgs(t *testing.T) {
result, err := prepareCommandLine("cmd", "arg1", "arg2")
assert.NoError(t, err)
assert.Equal(t, "cmd arg1 arg2", result)
}
func TestPrepareCommandLineArgWithWhitespaceQuotes(t *testing.T) {
result, err := prepareCommandLine("cmd", "arg with spaces")
assert.NoError(t, err)
assert.Equal(t, "cmd \"arg with spaces\"", result)
}
func TestPrepareCommandLineArgWithDoubleQuoteFails(t *testing.T) {
_, err := prepareCommandLine("cmd", "quoted\"arg")
assert.EqualError(t, err, "arg at index 0 contains an invalid double quote: quoted\"arg")
}

80
device_descriptor.go Normal file
View file

@ -0,0 +1,80 @@
package goadb
import "fmt"
//go:generate stringer -type=deviceDescriptorType
type deviceDescriptorType int
const (
// host:transport-any and host:<request>
DeviceAny deviceDescriptorType = iota
// host:transport:<serial> and host-serial:<serial>:<request>
DeviceSerial
// host:transport-usb and host-usb:<request>
DeviceUsb
// host:transport-local and host-local:<request>
DeviceLocal
)
type DeviceDescriptor struct {
descriptorType deviceDescriptorType
// Only used if Type is DeviceSerial.
serial string
}
func anyDevice() *DeviceDescriptor {
return &DeviceDescriptor{descriptorType: DeviceAny}
}
func anyUsbDevice() *DeviceDescriptor {
return &DeviceDescriptor{descriptorType: DeviceUsb}
}
func anyLocalDevice() *DeviceDescriptor {
return &DeviceDescriptor{descriptorType: DeviceLocal}
}
func deviceWithSerial(serial string) *DeviceDescriptor {
return &DeviceDescriptor{
descriptorType: DeviceSerial,
serial: serial,
}
}
func (d *DeviceDescriptor) String() string {
if d.descriptorType == DeviceSerial {
return fmt.Sprintf("%s[%s]", d.descriptorType, d.serial)
}
return d.descriptorType.String()
}
func (d *DeviceDescriptor) getHostPrefix() string {
switch d.descriptorType {
case DeviceAny:
return "host"
case DeviceUsb:
return "host-usb"
case DeviceLocal:
return "host-local"
case DeviceSerial:
return fmt.Sprintf("host-serial:%s", d.serial)
default:
panic(fmt.Sprintf("invalid DeviceDescriptorType: %v", d.descriptorType))
}
}
func (d *DeviceDescriptor) getTransportDescriptor() string {
switch d.descriptorType {
case DeviceAny:
return "transport-any"
case DeviceUsb:
return "transport-usb"
case DeviceLocal:
return "transport-local"
case DeviceSerial:
return fmt.Sprintf("transport:%s", d.serial)
default:
panic(fmt.Sprintf("invalid DeviceDescriptorType: %v", d.descriptorType))
}
}

View file

@ -6,41 +6,40 @@ import (
"strings"
)
// Device represents a connected Android device.
type Device struct {
type DeviceInfo struct {
// Always set.
Serial string
// Product, device, and model are not set in the short form.
Product string
Model string
Device string
Product string
Model string
DeviceInfo string
// Only set for devices connected via USB.
Usb string
}
// IsUsb returns true if the device is connected via USB.
func (d *Device) IsUsb() bool {
func (d *DeviceInfo) IsUsb() bool {
return d.Usb != ""
}
func newDevice(serial string, attrs map[string]string) (*Device, error) {
func newDevice(serial string, attrs map[string]string) (*DeviceInfo, error) {
if serial == "" {
return nil, fmt.Errorf("device serial cannot be blank")
}
return &Device{
Serial: serial,
Product: attrs["product"],
Model: attrs["model"],
Device: attrs["device"],
Usb: attrs["usb"],
return &DeviceInfo{
Serial: serial,
Product: attrs["product"],
Model: attrs["model"],
DeviceInfo: attrs["device"],
Usb: attrs["usb"],
}, nil
}
func parseDeviceList(list string, lineParseFunc func(string) (*Device, error)) ([]*Device, error) {
var devices []*Device
func parseDeviceList(list string, lineParseFunc func(string) (*DeviceInfo, error)) ([]*DeviceInfo, error) {
var devices []*DeviceInfo
scanner := bufio.NewScanner(strings.NewReader(list))
for scanner.Scan() {
@ -57,7 +56,7 @@ func parseDeviceList(list string, lineParseFunc func(string) (*Device, error)) (
return devices, nil
}
func parseDeviceShort(line string) (*Device, error) {
func parseDeviceShort(line string) (*DeviceInfo, error) {
fields := strings.Fields(line)
if len(fields) != 2 {
return nil, fmt.Errorf("malformed device line, expected 2 fields but found %d", len(fields))
@ -66,7 +65,7 @@ func parseDeviceShort(line string) (*Device, error) {
return newDevice(fields[0], map[string]string{})
}
func parseDeviceLong(line string) (*Device, error) {
func parseDeviceLong(line string) (*DeviceInfo, error) {
fields := strings.Fields(line)
if len(fields) < 5 {
return nil, fmt.Errorf("malformed device line, expected at least 5 fields but found %d", len(fields))

View file

@ -19,27 +19,27 @@ func ParseDeviceList(t *testing.T) {
func TestParseDeviceShort(t *testing.T) {
dev, err := parseDeviceShort("192.168.56.101:5555 device\n")
assert.NoError(t, err)
assert.Equal(t, &Device{
assert.Equal(t, &DeviceInfo{
Serial: "192.168.56.101:5555"}, dev)
}
func TestParseDeviceLong(t *testing.T) {
dev, err := parseDeviceLong("SERIAL device product:PRODUCT model:MODEL device:DEVICE\n")
assert.NoError(t, err)
assert.Equal(t, &Device{
Serial: "SERIAL",
Product: "PRODUCT",
Model: "MODEL",
Device: "DEVICE"}, dev)
assert.Equal(t, &DeviceInfo{
Serial: "SERIAL",
Product: "PRODUCT",
Model: "MODEL",
DeviceInfo: "DEVICE"}, dev)
}
func TestParseDeviceLongUsb(t *testing.T) {
dev, err := parseDeviceLong("SERIAL device usb:1234 product:PRODUCT model:MODEL device:DEVICE \n")
assert.NoError(t, err)
assert.Equal(t, &Device{
Serial: "SERIAL",
Product: "PRODUCT",
Model: "MODEL",
Device: "DEVICE",
Usb: "1234"}, dev)
assert.Equal(t, &DeviceInfo{
Serial: "SERIAL",
Product: "PRODUCT",
Model: "MODEL",
DeviceInfo: "DEVICE",
Usb: "1234"}, dev)
}

View file

@ -0,0 +1,16 @@
// generated by stringer -type=deviceDescriptorType; DO NOT EDIT
package goadb
import "fmt"
const _deviceDescriptorType_name = "DeviceAnyDeviceSerialDeviceUsbDeviceLocal"
var _deviceDescriptorType_index = [...]uint8{0, 9, 21, 30, 41}
func (i deviceDescriptorType) String() string {
if i < 0 || i+1 >= deviceDescriptorType(len(_deviceDescriptorType_index)) {
return fmt.Sprintf("deviceDescriptorType(%d)", i)
}
return _deviceDescriptorType_name[_deviceDescriptorType_index[i]:_deviceDescriptorType_index[i+1]]
}

98
dir_entries.go Normal file
View file

@ -0,0 +1,98 @@
package goadb
import (
"fmt"
"github.com/zach-klippenstein/goadb/wire"
)
// DirEntries iterates over directory entries.
type DirEntries struct {
scanner wire.SyncScanner
// Called when finished iterating (successfully or not).
doneHandler func()
currentEntry *DirEntry
err error
}
func (entries *DirEntries) Next() bool {
if entries.err != nil {
return false
}
entry, done, err := readNextDirListEntry(entries.scanner)
if err != nil {
entries.err = err
entries.onDone()
return false
}
entries.currentEntry = entry
if done {
entries.onDone()
return false
}
return true
}
func (entries *DirEntries) Entry() *DirEntry {
return entries.currentEntry
}
func (entries *DirEntries) Err() error {
return entries.err
}
func (entries *DirEntries) onDone() {
if entries.doneHandler != nil {
entries.doneHandler()
}
}
func readNextDirListEntry(s wire.SyncScanner) (entry *DirEntry, done bool, err error) {
id, err := s.ReadOctetString()
if err != nil {
return
}
if id == "DONE" {
done = true
return
} else if id != "DENT" {
err = fmt.Errorf("expected dir entry ID 'DENT', but got '%s'", id)
return
}
mode, err := s.ReadFileMode()
if err != nil {
err = fmt.Errorf("error reading file mode: %v", err)
return
}
size, err := s.ReadInt32()
if err != nil {
err = fmt.Errorf("error reading file size: %v", err)
return
}
mtime, err := s.ReadTime()
if err != nil {
err = fmt.Errorf("error reading file time: %v", err)
return
}
name, err := s.ReadString()
if err != nil {
err = fmt.Errorf("error reading file name: %v", err)
return
}
done = false
entry = &DirEntry{
Name: name,
Mode: mode,
Size: size,
ModifiedAt: mtime,
}
return
}

12
doc.go Normal file
View file

@ -0,0 +1,12 @@
/*
Package goadb is a Go interface to the Android Debug Bridge (adb).
See cmd/demo/demo.go for an example of how to use this library.
The client/server spec is defined at https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT.
WARNING This library is under heavy development, and its API is likely to change without notice.
*/
package goadb
// TODO(z): Write method-specific examples.

View file

@ -1,43 +1,40 @@
/*
Package goadb is a Go interface to the Android Debug Bridge (adb).
The client/server spec is defined at https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT.
WARNING This library is under heavy development, and its API is likely to change without notice.
*/
// TODO(z): Implement TrackDevices.
package goadb
import (
"fmt"
"os/exec"
"strconv"
"github.com/zach-klippenstein/goadb/wire"
)
// Dialer is a function that knows how to create a connection to an adb server.
type Dialer func() (*wire.Conn, error)
/*
HostClient interacts with host services on the adb server.
HostClient communicates with host services on the adb server.
Eg.
dialer := &HostClient{wire.Dial}
dialer.GetServerVersion()
TODO make this a real example.
TODO Finish implementing services.
client := NewHostClient()
client.StartServer()
client.ListDevices()
client.GetAnyDevice() // see DeviceClient
See list of services at https://android.googlesource.com/platform/system/core/+/master/adb/SERVICES.TXT.
*/
// TODO(z): Finish implementing host services.
type HostClient struct {
Dialer
dialer nilSafeDialer
}
func NewHostClient() *HostClient {
return NewHostClientDialer(nil)
}
func NewHostClientDialer(d wire.Dialer) *HostClient {
return &HostClient{nilSafeDialer{d}}
}
// GetServerVersion asks the ADB server for its internal version number.
func (c *HostClient) GetServerVersion() (int, error) {
resp, err := c.roundTripSingleResponse([]byte("host:version"))
resp, err := wire.RoundTripSingleResponse(c.dialer, "host:version")
if err != nil {
return 0, err
}
@ -53,13 +50,13 @@ Corresponds to the command:
adb kill-server
*/
func (c *HostClient) KillServer() error {
conn, err := c.Dialer()
conn, err := c.dialer.Dial()
if err != nil {
return err
}
defer conn.Close()
if err = conn.SendMessage([]byte("host:kill")); err != nil {
if err = wire.SendMessageString(conn, "host:kill"); err != nil {
return err
}
@ -84,7 +81,7 @@ Corresponds to the command:
adb devices
*/
func (c *HostClient) ListDeviceSerials() ([]string, error) {
resp, err := c.roundTripSingleResponse([]byte("host:devices"))
resp, err := wire.RoundTripSingleResponse(c.dialer, "host:devices")
if err != nil {
return nil, err
}
@ -107,8 +104,8 @@ ListDevices returns the list of connected devices.
Corresponds to the command:
adb devices -l
*/
func (c *HostClient) ListDevices() ([]*Device, error) {
resp, err := c.roundTripSingleResponse([]byte("host:devices-l"))
func (c *HostClient) ListDevices() ([]*DeviceInfo, error) {
resp, err := wire.RoundTripSingleResponse(c.dialer, "host:devices-l")
if err != nil {
return nil, err
}
@ -116,41 +113,33 @@ func (c *HostClient) ListDevices() ([]*Device, error) {
return parseDeviceList(string(resp), parseDeviceLong)
}
func (c *HostClient) roundTripSingleResponse(req []byte) (resp []byte, err error) {
conn, err := c.Dialer()
if err != nil {
return nil, err
}
defer conn.Close()
if err = conn.SendMessage(req); err != nil {
return nil, err
}
err = c.readStatusFailureAsError(conn)
if err != nil {
return nil, err
}
return conn.ReadMessage()
func (c *HostClient) GetDevice(d *DeviceInfo) *DeviceClient {
return c.GetDeviceWithSerial(d.Serial)
}
// Reads the status, and if failure, reads the message and returns it as an error.
// If the status is success, doesn't read the message.
func (c *HostClient) readStatusFailureAsError(conn *wire.Conn) error {
status, err := conn.ReadStatus()
if err != nil {
return err
}
if !status.IsSuccess() {
msg, err := conn.ReadMessage()
if err != nil {
return err
}
return fmt.Errorf("server error: %s", msg)
}
return nil
// GetDeviceWithSerial returns a client for the device with the specified serial number.
// Will return a client even if there is no matching device connected.
func (c *HostClient) GetDeviceWithSerial(serial string) *DeviceClient {
return c.getDevice(deviceWithSerial(serial))
}
// GetAnyDevice returns a client for any one connected device.
func (c *HostClient) GetAnyDevice() *DeviceClient {
return c.getDevice(anyDevice())
}
// GetUsbDevice returns a client for the USB device.
// Will return a client even if there is no device connected.
func (c *HostClient) GetUsbDevice() *DeviceClient {
return c.getDevice(anyUsbDevice())
}
// GetLocalDevice returns a client for the local device.
// Will return a client even if there is no device connected.
func (c *HostClient) GetLocalDevice() *DeviceClient {
return c.getDevice(anyLocalDevice())
}
func (c *HostClient) getDevice(descriptor *DeviceDescriptor) *DeviceClient {
return &DeviceClient{c.dialer, descriptor}
}

View file

@ -2,6 +2,7 @@ package goadb
import (
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
@ -9,29 +10,34 @@ import (
)
func TestGetServerVersion(t *testing.T) {
client := &HostClient{mockDialer(&MockServer{
s := &MockServer{
Status: wire.StatusSuccess,
Messages: []string{"000a"},
})}
}
client := NewHostClientDialer(s)
v, err := client.GetServerVersion()
assert.Equal(t, "host:version", s.Requests[0])
assert.NoError(t, err)
assert.Equal(t, 10, v)
}
func mockDialer(s *MockServer) Dialer {
return func() (*wire.Conn, error) {
return &wire.Conn{s, s, s}, nil
}
}
type MockServer struct {
Status wire.StatusCode
Status wire.StatusCode
// Messages are sent in order, each preceded by a length header.
Messages []string
// Each request is appended to this slice.
Requests []string
nextMsgIndex int
}
func (s *MockServer) Dial() (*wire.Conn, error) {
return wire.NewConn(s, s, s.Close), nil
}
func (s *MockServer) ReadStatus() (wire.StatusCode, error) {
return s.Status, nil
}
@ -45,7 +51,24 @@ func (s *MockServer) ReadMessage() ([]byte, error) {
return []byte(s.Messages[s.nextMsgIndex-1]), nil
}
func (s *MockServer) ReadUntilEof() ([]byte, error) {
var data []string
for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ {
data = append(data, s.Messages[s.nextMsgIndex])
}
return []byte(strings.Join(data, "")), nil
}
func (s *MockServer) SendMessage(msg []byte) error {
s.Requests = append(s.Requests, string(msg))
return nil
}
func (s *MockServer) NewSyncScanner() wire.SyncScanner {
return nil
}
func (s *MockServer) NewSyncSender() wire.SyncSender {
return nil
}

15
nil_safe_dialer.go Normal file
View file

@ -0,0 +1,15 @@
package goadb
import "github.com/zach-klippenstein/goadb/wire"
type nilSafeDialer struct {
wire.Dialer
}
func (d nilSafeDialer) Dial() (*wire.Conn, error) {
if d.Dialer == nil {
d.Dialer = wire.NewDialer("", 0)
}
return d.Dialer.Dial()
}

92
sync_client.go Normal file
View file

@ -0,0 +1,92 @@
// TODO(z): Implement send.
package goadb
import (
"fmt"
"io"
"os"
"time"
"github.com/zach-klippenstein/goadb/wire"
)
/*
DirEntry holds information about a directory entry on a device.
Unfortunately, adb doesn't seem to set the directory bit for directories.
*/
type DirEntry struct {
Name string
Mode os.FileMode
Size int32
ModifiedAt time.Time
}
func stat(conn *wire.SyncConn, path string) (*DirEntry, error) {
if err := conn.SendOctetString("STAT"); err != nil {
return nil, err
}
if err := conn.SendString(path); err != nil {
return nil, err
}
id, err := conn.ReadOctetString()
if err != nil {
return nil, err
}
if id != "STAT" {
return nil, fmt.Errorf("expected stat ID 'STAT', but got '%s'", id)
}
return readStat(conn)
}
func listDirEntries(conn *wire.SyncConn, path string) (entries *DirEntries, err error) {
if err = conn.SendOctetString("LIST"); err != nil {
return
}
if err = conn.SendString(path); err != nil {
return
}
return &DirEntries{
scanner: conn,
doneHandler: func() { conn.Close() },
}, nil
}
func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) {
if err := conn.SendOctetString("RECV"); err != nil {
return nil, err
}
if err := conn.SendString(path); err != nil {
return nil, err
}
return newSyncFileReader(conn), nil
}
func readStat(s wire.SyncScanner) (entry *DirEntry, err error) {
mode, err := s.ReadFileMode()
if err != nil {
err = fmt.Errorf("error reading file mode: %v", err)
return
}
size, err := s.ReadInt32()
if err != nil {
err = fmt.Errorf("error reading file size: %v", err)
return
}
mtime, err := s.ReadTime()
if err != nil {
err = fmt.Errorf("error reading file time: %v", err)
return
}
entry = &DirEntry{
Mode: mode,
Size: size,
ModifiedAt: mtime,
}
return
}

2
sync_client_test.go Normal file
View file

@ -0,0 +1,2 @@
// TODO(z): Implement tests for sync_client functions.
package goadb

69
sync_file_reader.go Normal file
View file

@ -0,0 +1,69 @@
package goadb
import (
"fmt"
"io"
"github.com/zach-klippenstein/goadb/wire"
)
// syncFileReader wraps a SyncConn that has requested to receive a file.
type syncFileReader struct {
// Reader used to read data from the adb connection.
scanner wire.SyncScanner
// Reader for the current chunk only.
chunkReader io.Reader
}
var _ io.ReadCloser = &syncFileReader{}
func newSyncFileReader(s wire.SyncScanner) io.ReadCloser {
return &syncFileReader{
scanner: s,
}
}
func (r *syncFileReader) Read(buf []byte) (n int, err error) {
if r.chunkReader == nil {
chunkReader, err := readNextChunk(r.scanner)
if err != nil {
// If this is EOF, we've read the last chunk.
// Either way, we want to pass it up to the caller.
return 0, err
}
r.chunkReader = chunkReader
}
n, err = r.chunkReader.Read(buf)
if err == io.EOF {
// End of current chunk, don't return an error, the next chunk will be
// read on the next call to this method.
r.chunkReader = nil
return n, nil
}
return n, err
}
func (r *syncFileReader) Close() error {
return r.scanner.Close()
}
// readNextChunk creates an io.LimitedReader for the next chunk of data,
// and returns io.EOF if the last chunk has been read.
func readNextChunk(r wire.SyncScanner) (io.Reader, error) {
id, err := r.ReadOctetString()
if err != nil {
return nil, err
}
switch id {
case "DATA":
return r.ReadBytes()
case "DONE":
return nil, io.EOF
default:
return nil, fmt.Errorf("expected chunk id 'DATA', but got '%s'", id)
}
}

82
sync_file_reader_test.go Normal file
View file

@ -0,0 +1,82 @@
package goadb
import (
"io"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zach-klippenstein/goadb/wire"
)
func TestReadNextChunk(t *testing.T) {
s := wire.NewSyncScanner(strings.NewReader(
"DATA\006\000\000\000hello DATA\005\000\000\000worldDONE"))
// Read 1st chunk
reader, err := readNextChunk(s)
assert.NoError(t, err)
assert.Equal(t, 6, reader.(*io.LimitedReader).N)
buf := make([]byte, 10)
n, err := reader.Read(buf)
assert.NoError(t, err)
assert.Equal(t, 6, n)
assert.Equal(t, "hello ", string(buf[:6]))
// Read 2nd chunk
reader, err = readNextChunk(s)
assert.NoError(t, err)
assert.Equal(t, 5, reader.(*io.LimitedReader).N)
buf = make([]byte, 10)
n, err = reader.Read(buf)
assert.NoError(t, err)
assert.Equal(t, 5, n)
assert.Equal(t, "world", string(buf[:5]))
// Read DONE
_, err = readNextChunk(s)
assert.Equal(t, io.EOF, err)
}
func TestReadNextChunkInvalidChunkId(t *testing.T) {
s := wire.NewSyncScanner(strings.NewReader(
"ATAD\006\000\000\000hello "))
// Read 1st chunk
_, err := readNextChunk(s)
assert.EqualError(t, err, "expected chunk id 'DATA', but got 'ATAD'")
}
func TestReadMultipleCalls(t *testing.T) {
s := wire.NewSyncScanner(strings.NewReader(
"DATA\006\000\000\000hello DATA\005\000\000\000worldDONE"))
reader := newSyncFileReader(s)
firstByte := make([]byte, 1)
_, err := io.ReadFull(reader, firstByte)
assert.NoError(t, err)
assert.Equal(t, "h", string(firstByte))
restFirstChunkBytes := make([]byte, 5)
_, err = io.ReadFull(reader, restFirstChunkBytes)
assert.NoError(t, err)
assert.Equal(t, "ello ", string(restFirstChunkBytes))
secondChunkBytes := make([]byte, 5)
_, err = io.ReadFull(reader, secondChunkBytes)
assert.NoError(t, err)
assert.Equal(t, "world", string(secondChunkBytes))
_, err = io.ReadFull(reader, make([]byte, 5))
assert.Equal(t, io.EOF, err)
}
func TestReadAll(t *testing.T) {
s := wire.NewSyncScanner(strings.NewReader(
"DATA\006\000\000\000hello DATA\005\000\000\000worldDONE"))
reader := newSyncFileReader(s)
buf := make([]byte, 20)
_, err := io.ReadFull(reader, buf)
assert.Equal(t, io.ErrUnexpectedEOF, err)
assert.Equal(t, "hello world\000", string(buf[:12]))
}

18
util.go Normal file
View file

@ -0,0 +1,18 @@
package goadb
import "strings"
import (
"regexp"
)
var (
whitespaceRegex = regexp.MustCompile(`^\s*$`)
)
func containsWhitespace(str string) bool {
return strings.ContainsAny(str, " \t\v")
}
func isBlank(str string) bool {
return whitespaceRegex.MatchString(str)
}

27
util_test.go Normal file
View file

@ -0,0 +1,27 @@
package goadb
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestContainsWhitespaceYes(t *testing.T) {
assert.True(t, containsWhitespace("hello world"))
}
func TestContainsWhitespaceNo(t *testing.T) {
assert.False(t, containsWhitespace("hello"))
}
func TestIsBlankWhenEmpty(t *testing.T) {
assert.True(t, isBlank(""))
}
func TestIsBlankWhenJustWhitespace(t *testing.T) {
assert.True(t, isBlank(" \t"))
}
func TestIsBlankNo(t *testing.T) {
assert.False(t, isBlank(" h "))
}

24
wire/adb_error.go Normal file
View file

@ -0,0 +1,24 @@
package wire
import (
"fmt"
)
type AdbError struct {
Request []byte
ServerMsg string
}
var _ error = &AdbError{}
func (e *AdbError) Error() string {
if e.Request == nil {
return fmt.Sprintf("server error: %s", e.ServerMsg)
} else {
return fmt.Sprintf("server error for request '%s': %s", e.Request, e.ServerMsg)
}
}
func incompleteMessage(description string, actual int, expected int) error {
return fmt.Errorf("incomplete %s: read %d bytes, expecting %d", description, actual, expected)
}

View file

@ -1,13 +1,18 @@
/*
The wire package implements the low-level part of the client/server wire protocol.
package wire
The protocol spec can be found at
https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT.
const (
// The official implementation of adb imposes an undocumented 255-byte limit
// on messages.
MaxMessageLength = 255
)
/*
Conn is a normal connection to an adb server.
For most cases, usage looks something like:
conn := wire.Dial()
conn.SendMessage(data)
conn.ReadStatus() == "OKAY" || "FAIL"
conn.ReadStatus() == StatusSuccess || StatusFailure
conn.ReadMessage()
conn.Close()
@ -18,52 +23,44 @@ it returns an io.EOF error.
For most commands, the server will close the connection after sending the response.
You should still always call Close() when you're done with the connection.
*/
package wire
import (
"fmt"
"io"
"net"
)
const (
// Default port the adb server listens on.
AdbPort = 5037
// The official implementation of adb imposes an undocumented 255-byte limit
// on messages.
MaxMessageLength = 255
)
// Conn is a connection to an adb server.
type Conn struct {
Scanner
Sender
io.Closer
closer func() error
}
// Dial connects to the adb server on the default port, AdbPort.
func Dial() (*Conn, error) {
return DialPort(AdbPort)
func NewConn(scanner Scanner, sender Sender, closer func() error) *Conn {
return &Conn{scanner, sender, closer}
}
// Dial connects to the adb server on port.
func DialPort(port int) (*Conn, error) {
return DialAddr(fmt.Sprintf("localhost:%d", port))
// Close closes the underlying connection.
func (c *Conn) Close() error {
if c.closer != nil {
return c.closer()
}
return nil
}
// Dial connects to the adb server at address.
func DialAddr(address string) (*Conn, error) {
netConn, err := net.Dial("tcp", address)
if err != nil {
// NewSyncConn returns connection that can operate in sync mode.
// The connection must already have been switched (by sending the sync command
// to a specific device), or the return connection will return an error.
func (c *Conn) NewSyncConn() *SyncConn {
return &SyncConn{
SyncScanner: c.Scanner.NewSyncScanner(),
SyncSender: c.Sender.NewSyncSender(),
}
}
// RoundTripSingleResponse sends a message to the server, and reads a single
// message response. If the reponse has a failure status code, returns it as an error.
func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) {
if err = conn.SendMessage(req); err != nil {
return nil, err
}
return &Conn{
Scanner: NewScanner(netConn),
Sender: NewSender(netConn),
Closer: netConn,
}, nil
}
if err = ReadStatusFailureAsError(conn, req); err != nil {
return nil, err
}
var _ io.Closer = &Conn{}
return conn.ReadMessage()
}

70
wire/dialer.go Normal file
View file

@ -0,0 +1,70 @@
package wire
import (
"fmt"
"net"
"runtime"
)
const (
// Default port the adb server listens on.
AdbPort = 5037
)
/*
Dialer knows how to create connections to an adb server.
*/
type Dialer interface {
Dial() (*Conn, error)
}
type netDialer struct {
Host string
Port int
}
func NewDialer(host string, port int) Dialer {
return &netDialer{host, port}
}
// Dial connects to the adb server on the host and port set on the netDialer.
// The zero-value will connect to the default, localhost:5037.
func (d *netDialer) Dial() (*Conn, error) {
host := d.Host
if host == "" {
host = "localhost"
}
port := d.Port
if port == 0 {
port = AdbPort
}
netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port))
if err != nil {
return nil, err
}
conn := &Conn{
Scanner: NewScanner(netConn),
Sender: NewSender(netConn),
closer: netConn.Close,
}
// Prevent leaking the network connection, not sure if TCPConn does this itself.
runtime.SetFinalizer(netConn, func(conn *net.TCPConn) {
conn.Close()
})
return conn, nil
}
func RoundTripSingleResponse(d Dialer, req string) ([]byte, error) {
conn, err := d.Dial()
if err != nil {
return nil, err
}
defer conn.Close()
return conn.RoundTripSingleResponse([]byte(req))
}

13
wire/doc.go Normal file
View file

@ -0,0 +1,13 @@
/*
Package wire implements the low-level part of the client/server wire protocol.
It also implements the "sync" wire format for file transfers.
This package is not intended to be used directly. goadb.HostClient and goadb.DeviceClient
use it to abstract away the bit-twiddling details of the protocol. You should only ever
need to work with the goadb package. Also, this package's API may change more frequently
than goadb's.
The protocol spec can be found at
https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT.
*/
package wire

View file

@ -3,6 +3,7 @@ package wire
import (
"fmt"
"io"
"io/ioutil"
"strconv"
)
@ -27,6 +28,8 @@ See Conn for more details.
type Scanner interface {
ReadStatus() (StatusCode, error)
ReadMessage() ([]byte, error)
ReadUntilEof() ([]byte, error)
NewSyncScanner() SyncScanner
}
type realScanner struct {
@ -73,6 +76,38 @@ func (s *realScanner) ReadMessage() ([]byte, error) {
return data, nil
}
func (s *realScanner) ReadUntilEof() ([]byte, error) {
return ioutil.ReadAll(s.reader)
}
func (s *realScanner) NewSyncScanner() SyncScanner {
return NewSyncScanner(s.reader)
}
// Reads the status, and if failure, reads the message and returns it as an error.
// If the status is success, doesn't read the message.
// req is just used to populate the AdbError, and can be nil.
func ReadStatusFailureAsError(s Scanner, req []byte) error {
status, err := s.ReadStatus()
if err != nil {
return err
}
if !status.IsSuccess() {
msg, err := s.ReadMessage()
if err != nil {
return err
}
return &AdbError{
Request: req,
ServerMsg: string(msg),
}
}
return nil
}
func (s *realScanner) readLength() (int, error) {
lengthHex := make([]byte, 4)
n, err := io.ReadFull(s.reader, lengthHex)
@ -95,8 +130,4 @@ func (s *realScanner) readLength() (int, error) {
return int(length), nil
}
func incompleteMessage(description string, actual int, expected int) error {
return fmt.Errorf("incomplete %s: read %d bytes, expecting %d", description, actual, expected)
}
var _ Scanner = &realScanner{}

View file

@ -8,6 +8,7 @@ import (
// Sender sends messages to the server.
type Sender interface {
SendMessage(msg []byte) error
NewSyncSender() SyncSender
}
type realSender struct {
@ -31,15 +32,8 @@ func (s *realSender) SendMessage(msg []byte) error {
return writeFully(s.writer, []byte(lengthAndMsg))
}
func writeFully(w io.Writer, data []byte) error {
for len(data) > 0 {
n, err := w.Write(data)
if err != nil {
return err
}
data = data[n:]
}
return nil
func (s *realSender) NewSyncSender() SyncSender {
return NewSyncSender(s.writer)
}
var _ Sender = &realSender{}

198
wire/sync.go Normal file
View file

@ -0,0 +1,198 @@
// TODO(z): Write SyncSender.SendBytes().
package wire
import (
"encoding/binary"
"fmt"
"io"
"os"
"time"
)
const (
// Chunks cannot be longer than 64k.
MaxChunkSize = 64 * 1024
)
/*
SyncConn is a connection to the adb server in sync mode.
Assumes the connection has been put into sync mode (by sending "sync" in transport mode).
The adb sync protocol is defined at
https://android.googlesource.com/platform/system/core/+/master/adb/SYNC.TXT.
Unlike the normal adb protocol (implemented in Conn), the sync protocol is binary.
Lengths are binary-encoded (little-endian) instead of hex.
Notes on Encoding
Length headers and other integers are encoded in little-endian, with 32 bits.
File mode seems to be encoded as POSIX file mode.
Modification time seems to be the Unix timestamp format, i.e. seconds since Epoch UTC.
*/
type SyncConn struct {
SyncScanner
SyncSender
}
func (c *SyncConn) Close() error {
return c.SyncScanner.Close()
}
type SyncScanner interface {
// ReadOctetString reads a 4-byte string.
ReadOctetString() (string, error)
ReadInt32() (int32, error)
ReadFileMode() (os.FileMode, error)
ReadTime() (time.Time, error)
// Reads an octet length, followed by length bytes.
ReadString() (string, error)
// Reads an octet length, and returns a reader that will read length
// bytes (see io.LimitReader). The returned reader should be fully
// read before reading anything off the Scanner again.
ReadBytes() (io.Reader, error)
// Closes the underlying reader.
Close() error
}
type SyncSender interface {
// SendOctetString sends a 4-byte string.
SendOctetString(string) error
SendInt32(int32) error
SendFileMode(os.FileMode) error
SendTime(time.Time) error
// Sends len(bytes) as an octet, followed by bytes.
SendString(str string) error
}
type realSyncScanner struct {
io.Reader
}
type realSyncSender struct {
io.Writer
}
func NewSyncScanner(r io.Reader) SyncScanner {
return &realSyncScanner{r}
}
func NewSyncSender(w io.Writer) SyncSender {
return &realSyncSender{w}
}
func RequireOctetString(s SyncScanner, expected string) error {
actual, err := s.ReadOctetString()
if err != nil {
return fmt.Errorf("expected to read '%s', got err: %v", expected, err)
}
if actual != expected {
return fmt.Errorf("expected to read '%s', got '%s'", expected, actual)
}
return nil
}
func (s *realSyncScanner) ReadOctetString() (string, error) {
octet := make([]byte, 4)
n, err := io.ReadFull(s.Reader, octet)
if err != nil && err != io.ErrUnexpectedEOF {
return "", err
} else if err == io.ErrUnexpectedEOF {
return "", incompleteMessage("octet", n, 4)
}
return string(octet), nil
}
func (s *realSyncSender) SendOctetString(str string) error {
if len(str) != 4 {
return fmt.Errorf("octet string must be exactly 4 bytes: '%s'", str)
}
return writeFully(s.Writer, []byte(str))
}
func (s *realSyncScanner) ReadInt32() (int32, error) {
var value int32
err := binary.Read(s.Reader, binary.LittleEndian, &value)
return value, err
}
func (s *realSyncSender) SendInt32(val int32) error {
return binary.Write(s.Writer, binary.LittleEndian, val)
}
func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) {
var value uint32
err := binary.Read(s.Reader, binary.LittleEndian, &value)
return os.FileMode(value), err
}
func (s *realSyncSender) SendFileMode(mode os.FileMode) error {
return binary.Write(s.Writer, binary.LittleEndian, mode)
}
func (s *realSyncScanner) ReadTime() (time.Time, error) {
seconds, err := s.ReadInt32()
if err != nil {
return time.Time{}, err
}
return time.Unix(int64(seconds), 0).UTC(), nil
}
func (s *realSyncSender) SendTime(t time.Time) error {
return s.SendInt32(int32(t.Unix()))
}
func (s *realSyncScanner) ReadString() (string, error) {
length, err := s.ReadInt32()
if err != nil {
return "", err
}
bytes := make([]byte, length)
n, err := io.ReadFull(s.Reader, bytes)
if err != nil && err != io.ErrUnexpectedEOF {
return "", err
} else if err == io.ErrUnexpectedEOF {
return "", incompleteMessage("bytes", n, int(length))
}
return string(bytes), nil
}
func (s *realSyncSender) SendString(str string) error {
length := len(str)
if length > MaxChunkSize {
// This limit might not apply to filenames, but it's big enough
// that I don't think it will be a problem.
return fmt.Errorf("str must be <= %d in length", MaxChunkSize)
}
if err := s.SendInt32(int32(length)); err != nil {
return err
}
return writeFully(s.Writer, []byte(str))
}
func (s *realSyncScanner) ReadBytes() (io.Reader, error) {
length, err := s.ReadInt32()
if err != nil {
return nil, err
}
return io.LimitReader(s.Reader, int64(length)), nil
}
func (s *realSyncScanner) Close() error {
if closer, ok := s.Reader.(io.Closer); ok {
return closer.Close()
}
return nil
}

87
wire/sync_test.go Normal file
View file

@ -0,0 +1,87 @@
package wire
import (
"bytes"
"io/ioutil"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
var (
someTime = time.Date(2015, 04, 12, 20, 7, 51, 0, time.UTC)
// The little-endian encoding of someTime.Unix()
someTimeEncoded = []byte{151, 208, 42, 85}
)
func TestSyncReadOctetString(t *testing.T) {
s := NewSyncScanner(strings.NewReader("helo"))
str, err := s.ReadOctetString()
assert.NoError(t, err)
assert.Equal(t, "helo", str)
}
func TestSyncSendOctetString(t *testing.T) {
var buf bytes.Buffer
s := NewSyncSender(&buf)
err := s.SendOctetString("helo")
assert.NoError(t, err)
assert.Equal(t, "helo", buf.String())
}
func TestSyncSendOctetStringTooLong(t *testing.T) {
var buf bytes.Buffer
s := NewSyncSender(&buf)
err := s.SendOctetString("hello")
assert.EqualError(t, err, "octet string must be exactly 4 bytes: 'hello'")
}
func TestSyncReadTime(t *testing.T) {
s := NewSyncScanner(bytes.NewReader(someTimeEncoded))
decoded, err := s.ReadTime()
assert.NoError(t, err)
assert.Equal(t, someTime, decoded)
}
func TestSyncSendTime(t *testing.T) {
var buf bytes.Buffer
s := NewSyncSender(&buf)
err := s.SendTime(someTime)
assert.NoError(t, err)
assert.Equal(t, someTimeEncoded, buf.Bytes())
}
func TestSyncReadString(t *testing.T) {
s := NewSyncScanner(strings.NewReader("\005\000\000\000hello"))
str, err := s.ReadString()
assert.NoError(t, err)
assert.Equal(t, "hello", str)
}
func TestSyncReadStringTooShort(t *testing.T) {
s := NewSyncScanner(strings.NewReader("\005\000\000\000h"))
_, err := s.ReadString()
assert.EqualError(t, err, "incomplete bytes: read 1 bytes, expecting 5")
}
func TestSyncSendString(t *testing.T) {
var buf bytes.Buffer
s := NewSyncSender(&buf)
err := s.SendString("hello")
assert.NoError(t, err)
assert.Equal(t, "\005\000\000\000hello", buf.String())
}
func TestSyncReadBytes(t *testing.T) {
s := NewSyncScanner(strings.NewReader("\005\000\000\000helloworld"))
reader, err := s.ReadBytes()
assert.NoError(t, err)
assert.NotNil(t, reader)
str, err := ioutil.ReadAll(reader)
assert.NoError(t, err)
assert.Equal(t, "hello", string(str))
}

16
wire/util.go Normal file
View file

@ -0,0 +1,16 @@
package wire
import "io"
// writeFully writes all of data to w.
// Inverse of io.ReadFully().
func writeFully(w io.Writer, data []byte) error {
for len(data) > 0 {
n, err := w.Write(data)
if err != nil {
return err
}
data = data[n:]
}
return nil
}