From 4b9891533a795fada0c0fce684e3fad617f5f64b Mon Sep 17 00:00:00 2001 From: Zach Klippenstein Date: Sun, 12 Apr 2015 13:34:20 -0700 Subject: [PATCH] Started host and local service clients. Milestone: the demo app prints /proc/loadavg from the device. --- cmd/demo/demo.go | 91 ++++++++++- cmd/raw-adb/raw-adb.go | 3 +- device_client.go | 208 +++++++++++++++++++++++++ device_client_test.go | 68 ++++++++ device_descriptor.go | 80 ++++++++++ devices.go => device_info.go | 33 ++-- devices_test.go => device_info_test.go | 24 +-- devicedescriptortype_string.go | 16 ++ dir_entries.go | 98 ++++++++++++ doc.go | 12 ++ host_client.go | 109 ++++++------- host_client_test.go | 41 +++-- nil_safe_dialer.go | 15 ++ sync_client.go | 92 +++++++++++ sync_client_test.go | 2 + sync_file_reader.go | 69 ++++++++ sync_file_reader_test.go | 82 ++++++++++ util.go | 18 +++ util_test.go | 27 ++++ wire/adb_error.go | 24 +++ wire/conn.go | 79 +++++----- wire/dialer.go | 70 +++++++++ wire/doc.go | 13 ++ wire/scanner.go | 39 ++++- wire/sender.go | 12 +- wire/sync.go | 198 +++++++++++++++++++++++ wire/sync_test.go | 87 +++++++++++ wire/util.go | 16 ++ 28 files changed, 1470 insertions(+), 156 deletions(-) create mode 100644 device_client.go create mode 100644 device_client_test.go create mode 100644 device_descriptor.go rename devices.go => device_info.go (73%) rename devices_test.go => device_info_test.go (73%) create mode 100644 devicedescriptortype_string.go create mode 100644 dir_entries.go create mode 100644 doc.go create mode 100644 nil_safe_dialer.go create mode 100644 sync_client.go create mode 100644 sync_client_test.go create mode 100644 sync_file_reader.go create mode 100644 sync_file_reader_test.go create mode 100644 util.go create mode 100644 util_test.go create mode 100644 wire/adb_error.go create mode 100644 wire/dialer.go create mode 100644 wire/doc.go create mode 100644 wire/sync.go create mode 100644 wire/sync_test.go create mode 100644 wire/util.go diff --git a/cmd/demo/demo.go b/cmd/demo/demo.go index cfd6647..5ec45ac 100644 --- a/cmd/demo/demo.go +++ b/cmd/demo/demo.go @@ -1,15 +1,22 @@ +// An app demonstrating most of the library's features. package main import ( + "flag" "fmt" + "io/ioutil" "log" adb "github.com/zach-klippenstein/goadb" "github.com/zach-klippenstein/goadb/wire" ) +var port = flag.Int("p", wire.AdbPort, "") + func main() { - client := &adb.HostClient{wire.Dial} + flag.Parse() + + client := adb.NewHostClientDialer(wire.NewDialer("", *port)) fmt.Println("Starting server…") client.StartServer() @@ -28,6 +35,84 @@ func main() { fmt.Printf("\t%+v\n", *device) } - fmt.Println("Killing server…") - client.KillServer() + PrintDeviceInfoAndError(client.GetAnyDevice()) + PrintDeviceInfoAndError(client.GetLocalDevice()) + PrintDeviceInfoAndError(client.GetUsbDevice()) + + serials, err := client.ListDeviceSerials() + if err != nil { + log.Fatal(err) + } + for _, serial := range serials { + PrintDeviceInfoAndError(client.GetDeviceWithSerial(serial)) + } + + //fmt.Println("Killing server…") + //client.KillServer() +} + +func PrintDeviceInfoAndError(device *adb.DeviceClient) { + if err := PrintDeviceInfo(device); err != nil { + log.Println(err) + } +} + +func PrintDeviceInfo(device *adb.DeviceClient) error { + serialNo, err := device.GetSerial() + if err != nil { + return err + } + devPath, err := device.GetDevicePath() + if err != nil { + return err + } + state, err := device.GetState() + if err != nil { + return err + } + + fmt.Println(device) + fmt.Printf("\tserial no: %s\n", serialNo) + fmt.Printf("\tdevPath: %s\n", devPath) + fmt.Printf("\tstate: %s\n", state) + + cmdOutput, err := device.RunCommand("pwd") + if err != nil { + fmt.Println("\terror running command:", err) + } + fmt.Printf("\tcmd output: %s\n", cmdOutput) + + stat, err := device.Stat("/sdcard") + if err != nil { + fmt.Println("\terror stating /sdcard:", err) + } + fmt.Printf("\tstat \"/sdcard\": %+v\n", stat) + + fmt.Println("\tfiles in \"/\":") + entries, err := device.ListDirEntries("/") + if err != nil { + fmt.Println("\terror listing files:", err) + } else { + for entries.Next() { + fmt.Printf("\t%+v\n", *entries.Entry()) + } + if entries.Err() != nil { + fmt.Println("\terror listing files:", err) + } + } + + fmt.Print("\tload avg: ") + loadavgReader, err := device.OpenRead("/proc/loadavg") + if err != nil { + fmt.Println("\terror opening file:", err) + } else { + loadAvg, err := ioutil.ReadAll(loadavgReader) + if err != nil { + fmt.Println("\terror reading file:", err) + } else { + fmt.Println(string(loadAvg)) + } + } + + return nil } diff --git a/cmd/raw-adb/raw-adb.go b/cmd/raw-adb/raw-adb.go index e9a06af..8d00e0c 100644 --- a/cmd/raw-adb/raw-adb.go +++ b/cmd/raw-adb/raw-adb.go @@ -1,3 +1,4 @@ +// A simple tool for sending raw messages to an adb server. package main import ( @@ -47,7 +48,7 @@ func readLine() string { } func doCommand(cmd string) error { - conn, err := wire.DialPort(*port) + conn, err := wire.NewDialer("", *port).Dial() if err != nil { log.Fatal(err) } diff --git a/device_client.go b/device_client.go new file mode 100644 index 0000000..6bba1e6 --- /dev/null +++ b/device_client.go @@ -0,0 +1,208 @@ +package goadb + +import ( + "fmt" + "io" + "strings" + + "github.com/zach-klippenstein/goadb/wire" +) + +/* +DeviceClient communicates with a specific Android device. +*/ +type DeviceClient struct { + dialer nilSafeDialer + descriptor *DeviceDescriptor +} + +func (c *DeviceClient) String() string { + return c.descriptor.String() +} + +// get-product is documented, but not implemented in the server. +// TODO(z): Make getProduct exported if get-product is ever implemented in adb. +func (c *DeviceClient) getProduct() (string, error) { + return c.getAttribute("get-product") +} + +func (c *DeviceClient) GetSerial() (string, error) { + return c.getAttribute("get-serialno") +} + +func (c *DeviceClient) GetDevicePath() (string, error) { + return c.getAttribute("get-devpath") +} + +func (c *DeviceClient) GetState() (string, error) { + return c.getAttribute("get-state") +} + +/* +RunCommand runs the specified commands on a shell on the device. + +From the Android docs: + Run 'command arg1 arg2 ...' in a shell on the device, and return + its output and error streams. Note that arguments must be separated + by spaces. If an argument contains a space, it must be quoted with + double-quotes. Arguments cannot contain double quotes or things + will go very wrong. + + Note that this is the non-interactive version of "adb shell" +Source: https://android.googlesource.com/platform/system/core/+/master/adb/SERVICES.TXT + +This method quotes the arguments for you, and will return an error if any of them +contain double quotes. +*/ +func (c *DeviceClient) RunCommand(cmd string, args ...string) (string, error) { + cmd, err := prepareCommandLine(cmd, args...) + if err != nil { + return "", err + } + + conn, err := c.dialDevice() + if err != nil { + return "", err + } + defer conn.Close() + + req := fmt.Sprintf("shell:%s", cmd) + + // Shell responses are special, they don't include a length header. + // We read until the stream is closed. + // So, we can't use conn.RoundTripSingleResponse. + if err = conn.SendMessage([]byte(req)); err != nil { + return "", err + } + if err = wire.ReadStatusFailureAsError(conn, []byte(req)); err != nil { + return "", err + } + + resp, err := conn.ReadUntilEof() + if err != nil { + return "", err + } + + return string(resp), nil +} + +/* +Remount, from the docs, + Ask adbd to remount the device's filesystem in read-write mode, + instead of read-only. This is usually necessary before performing + an "adb sync" or "adb push" request. + This request may not succeed on certain builds which do not allow + that. +Source: https://android.googlesource.com/platform/system/core/+/master/adb/SERVICES.TXT +*/ +func (c *DeviceClient) Remount() (string, error) { + conn, err := c.dialDevice() + if err != nil { + return "", err + } + defer conn.Close() + + resp, err := conn.RoundTripSingleResponse([]byte("remount")) + return string(resp), err +} + +func (c *DeviceClient) ListDirEntries(path string) (*DirEntries, error) { + conn, err := c.getSyncConn() + if err != nil { + return nil, err + } + + return listDirEntries(conn, path) +} + +func (c *DeviceClient) Stat(path string) (*DirEntry, error) { + conn, err := c.getSyncConn() + if err != nil { + return nil, err + } + + return stat(conn, path) +} + +func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) { + conn, err := c.getSyncConn() + if err != nil { + return nil, err + } + + return receiveFile(conn, path) +} + +// getAttribute returns the first message returned by the server by running +// :, where host-prefix is determined from the DeviceDescriptor. +func (c *DeviceClient) getAttribute(attr string) (string, error) { + resp, err := wire.RoundTripSingleResponse(c.dialer, + fmt.Sprintf("%s:%s", c.descriptor.getHostPrefix(), attr)) + if err != nil { + return "", err + } + return string(resp), nil +} + +// dialDevice switches the connection to communicate directly with the device +// by requesting the transport defined by the DeviceDescriptor. +func (c *DeviceClient) dialDevice() (*wire.Conn, error) { + conn, err := c.dialer.Dial() + if err != nil { + return nil, err + } + + req := fmt.Sprintf("host:%s", c.descriptor.getTransportDescriptor()) + if err = wire.SendMessageString(conn, req); err != nil { + conn.Close() + return nil, err + } + + if err = wire.ReadStatusFailureAsError(conn, []byte(req)); err != nil { + conn.Close() + return nil, err + } + + return conn, nil +} + +func (c *DeviceClient) getSyncConn() (*wire.SyncConn, error) { + conn, err := c.dialDevice() + if err != nil { + return nil, err + } + + // Switch the connection to sync mode. + if err := wire.SendMessageString(conn, "sync:"); err != nil { + return nil, err + } + if err := wire.ReadStatusFailureAsError(conn, []byte("sync")); err != nil { + return nil, err + } + + return conn.NewSyncConn(), nil +} + +// prepareCommandLine validates the command and argument strings, quotes +// arguments if required, and joins them into a valid adb command string. +func prepareCommandLine(cmd string, args ...string) (string, error) { + if isBlank(cmd) { + return "", fmt.Errorf("command cannot be empty") + } + + for i, arg := range args { + if strings.ContainsRune(arg, '"') { + return "", fmt.Errorf("arg at index %d contains an invalid double quote: %s", i, arg) + } + if containsWhitespace(arg) { + args[i] = fmt.Sprintf("\"%s\"", arg) + } + } + + // Prepend the comand to the args array. + if len(args) > 0 { + cmd = fmt.Sprintf("%s %s", cmd, strings.Join(args, " ")) + } + + return cmd, nil +} diff --git a/device_client_test.go b/device_client_test.go new file mode 100644 index 0000000..37131d3 --- /dev/null +++ b/device_client_test.go @@ -0,0 +1,68 @@ +package goadb + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/wire" +) + +func TestGetAttribute(t *testing.T) { + s := &MockServer{ + Status: wire.StatusSuccess, + Messages: []string{"value"}, + } + client := &DeviceClient{nilSafeDialer{s}, deviceWithSerial("serial")} + + v, err := client.getAttribute("attr") + assert.Equal(t, "host-serial:serial:attr", s.Requests[0]) + assert.NoError(t, err) + assert.Equal(t, "value", v) +} + +func TestRunCommandNoArgs(t *testing.T) { + s := &MockServer{ + Status: wire.StatusSuccess, + Messages: []string{"output"}, + } + client := &DeviceClient{nilSafeDialer{s}, anyDevice()} + + v, err := client.RunCommand("cmd") + assert.Equal(t, "host:transport-any", s.Requests[0]) + assert.Equal(t, "shell:cmd", s.Requests[1]) + assert.NoError(t, err) + assert.Equal(t, "output", v) +} + +func TestPrepareCommandLineNoArgs(t *testing.T) { + result, err := prepareCommandLine("cmd") + assert.NoError(t, err) + assert.Equal(t, "cmd", result) +} + +func TestPrepareCommandLineEmptyCommand(t *testing.T) { + _, err := prepareCommandLine("") + assert.EqualError(t, err, "command cannot be empty") +} + +func TestPrepareCommandLineBlankCommand(t *testing.T) { + _, err := prepareCommandLine(" ") + assert.EqualError(t, err, "command cannot be empty") +} + +func TestPrepareCommandLineCleanArgs(t *testing.T) { + result, err := prepareCommandLine("cmd", "arg1", "arg2") + assert.NoError(t, err) + assert.Equal(t, "cmd arg1 arg2", result) +} + +func TestPrepareCommandLineArgWithWhitespaceQuotes(t *testing.T) { + result, err := prepareCommandLine("cmd", "arg with spaces") + assert.NoError(t, err) + assert.Equal(t, "cmd \"arg with spaces\"", result) +} + +func TestPrepareCommandLineArgWithDoubleQuoteFails(t *testing.T) { + _, err := prepareCommandLine("cmd", "quoted\"arg") + assert.EqualError(t, err, "arg at index 0 contains an invalid double quote: quoted\"arg") +} diff --git a/device_descriptor.go b/device_descriptor.go new file mode 100644 index 0000000..739f7d8 --- /dev/null +++ b/device_descriptor.go @@ -0,0 +1,80 @@ +package goadb + +import "fmt" + +//go:generate stringer -type=deviceDescriptorType +type deviceDescriptorType int + +const ( + // host:transport-any and host: + DeviceAny deviceDescriptorType = iota + // host:transport: and host-serial:: + DeviceSerial + // host:transport-usb and host-usb: + DeviceUsb + // host:transport-local and host-local: + DeviceLocal +) + +type DeviceDescriptor struct { + descriptorType deviceDescriptorType + + // Only used if Type is DeviceSerial. + serial string +} + +func anyDevice() *DeviceDescriptor { + return &DeviceDescriptor{descriptorType: DeviceAny} +} + +func anyUsbDevice() *DeviceDescriptor { + return &DeviceDescriptor{descriptorType: DeviceUsb} +} + +func anyLocalDevice() *DeviceDescriptor { + return &DeviceDescriptor{descriptorType: DeviceLocal} +} + +func deviceWithSerial(serial string) *DeviceDescriptor { + return &DeviceDescriptor{ + descriptorType: DeviceSerial, + serial: serial, + } +} + +func (d *DeviceDescriptor) String() string { + if d.descriptorType == DeviceSerial { + return fmt.Sprintf("%s[%s]", d.descriptorType, d.serial) + } + return d.descriptorType.String() +} + +func (d *DeviceDescriptor) getHostPrefix() string { + switch d.descriptorType { + case DeviceAny: + return "host" + case DeviceUsb: + return "host-usb" + case DeviceLocal: + return "host-local" + case DeviceSerial: + return fmt.Sprintf("host-serial:%s", d.serial) + default: + panic(fmt.Sprintf("invalid DeviceDescriptorType: %v", d.descriptorType)) + } +} + +func (d *DeviceDescriptor) getTransportDescriptor() string { + switch d.descriptorType { + case DeviceAny: + return "transport-any" + case DeviceUsb: + return "transport-usb" + case DeviceLocal: + return "transport-local" + case DeviceSerial: + return fmt.Sprintf("transport:%s", d.serial) + default: + panic(fmt.Sprintf("invalid DeviceDescriptorType: %v", d.descriptorType)) + } +} diff --git a/devices.go b/device_info.go similarity index 73% rename from devices.go rename to device_info.go index 6521c57..c1af642 100644 --- a/devices.go +++ b/device_info.go @@ -6,41 +6,40 @@ import ( "strings" ) -// Device represents a connected Android device. -type Device struct { +type DeviceInfo struct { // Always set. Serial string // Product, device, and model are not set in the short form. - Product string - Model string - Device string + Product string + Model string + DeviceInfo string // Only set for devices connected via USB. Usb string } // IsUsb returns true if the device is connected via USB. -func (d *Device) IsUsb() bool { +func (d *DeviceInfo) IsUsb() bool { return d.Usb != "" } -func newDevice(serial string, attrs map[string]string) (*Device, error) { +func newDevice(serial string, attrs map[string]string) (*DeviceInfo, error) { if serial == "" { return nil, fmt.Errorf("device serial cannot be blank") } - return &Device{ - Serial: serial, - Product: attrs["product"], - Model: attrs["model"], - Device: attrs["device"], - Usb: attrs["usb"], + return &DeviceInfo{ + Serial: serial, + Product: attrs["product"], + Model: attrs["model"], + DeviceInfo: attrs["device"], + Usb: attrs["usb"], }, nil } -func parseDeviceList(list string, lineParseFunc func(string) (*Device, error)) ([]*Device, error) { - var devices []*Device +func parseDeviceList(list string, lineParseFunc func(string) (*DeviceInfo, error)) ([]*DeviceInfo, error) { + var devices []*DeviceInfo scanner := bufio.NewScanner(strings.NewReader(list)) for scanner.Scan() { @@ -57,7 +56,7 @@ func parseDeviceList(list string, lineParseFunc func(string) (*Device, error)) ( return devices, nil } -func parseDeviceShort(line string) (*Device, error) { +func parseDeviceShort(line string) (*DeviceInfo, error) { fields := strings.Fields(line) if len(fields) != 2 { return nil, fmt.Errorf("malformed device line, expected 2 fields but found %d", len(fields)) @@ -66,7 +65,7 @@ func parseDeviceShort(line string) (*Device, error) { return newDevice(fields[0], map[string]string{}) } -func parseDeviceLong(line string) (*Device, error) { +func parseDeviceLong(line string) (*DeviceInfo, error) { fields := strings.Fields(line) if len(fields) < 5 { return nil, fmt.Errorf("malformed device line, expected at least 5 fields but found %d", len(fields)) diff --git a/devices_test.go b/device_info_test.go similarity index 73% rename from devices_test.go rename to device_info_test.go index 2920efe..3fc3189 100644 --- a/devices_test.go +++ b/device_info_test.go @@ -19,27 +19,27 @@ func ParseDeviceList(t *testing.T) { func TestParseDeviceShort(t *testing.T) { dev, err := parseDeviceShort("192.168.56.101:5555 device\n") assert.NoError(t, err) - assert.Equal(t, &Device{ + assert.Equal(t, &DeviceInfo{ Serial: "192.168.56.101:5555"}, dev) } func TestParseDeviceLong(t *testing.T) { dev, err := parseDeviceLong("SERIAL device product:PRODUCT model:MODEL device:DEVICE\n") assert.NoError(t, err) - assert.Equal(t, &Device{ - Serial: "SERIAL", - Product: "PRODUCT", - Model: "MODEL", - Device: "DEVICE"}, dev) + assert.Equal(t, &DeviceInfo{ + Serial: "SERIAL", + Product: "PRODUCT", + Model: "MODEL", + DeviceInfo: "DEVICE"}, dev) } func TestParseDeviceLongUsb(t *testing.T) { dev, err := parseDeviceLong("SERIAL device usb:1234 product:PRODUCT model:MODEL device:DEVICE \n") assert.NoError(t, err) - assert.Equal(t, &Device{ - Serial: "SERIAL", - Product: "PRODUCT", - Model: "MODEL", - Device: "DEVICE", - Usb: "1234"}, dev) + assert.Equal(t, &DeviceInfo{ + Serial: "SERIAL", + Product: "PRODUCT", + Model: "MODEL", + DeviceInfo: "DEVICE", + Usb: "1234"}, dev) } diff --git a/devicedescriptortype_string.go b/devicedescriptortype_string.go new file mode 100644 index 0000000..231fbb0 --- /dev/null +++ b/devicedescriptortype_string.go @@ -0,0 +1,16 @@ +// generated by stringer -type=deviceDescriptorType; DO NOT EDIT + +package goadb + +import "fmt" + +const _deviceDescriptorType_name = "DeviceAnyDeviceSerialDeviceUsbDeviceLocal" + +var _deviceDescriptorType_index = [...]uint8{0, 9, 21, 30, 41} + +func (i deviceDescriptorType) String() string { + if i < 0 || i+1 >= deviceDescriptorType(len(_deviceDescriptorType_index)) { + return fmt.Sprintf("deviceDescriptorType(%d)", i) + } + return _deviceDescriptorType_name[_deviceDescriptorType_index[i]:_deviceDescriptorType_index[i+1]] +} diff --git a/dir_entries.go b/dir_entries.go new file mode 100644 index 0000000..6603469 --- /dev/null +++ b/dir_entries.go @@ -0,0 +1,98 @@ +package goadb + +import ( + "fmt" + + "github.com/zach-klippenstein/goadb/wire" +) + +// DirEntries iterates over directory entries. +type DirEntries struct { + scanner wire.SyncScanner + + // Called when finished iterating (successfully or not). + doneHandler func() + + currentEntry *DirEntry + err error +} + +func (entries *DirEntries) Next() bool { + if entries.err != nil { + return false + } + + entry, done, err := readNextDirListEntry(entries.scanner) + if err != nil { + entries.err = err + entries.onDone() + return false + } + + entries.currentEntry = entry + if done { + entries.onDone() + return false + } + + return true +} + +func (entries *DirEntries) Entry() *DirEntry { + return entries.currentEntry +} + +func (entries *DirEntries) Err() error { + return entries.err +} + +func (entries *DirEntries) onDone() { + if entries.doneHandler != nil { + entries.doneHandler() + } +} + +func readNextDirListEntry(s wire.SyncScanner) (entry *DirEntry, done bool, err error) { + id, err := s.ReadOctetString() + if err != nil { + return + } + + if id == "DONE" { + done = true + return + } else if id != "DENT" { + err = fmt.Errorf("expected dir entry ID 'DENT', but got '%s'", id) + return + } + + mode, err := s.ReadFileMode() + if err != nil { + err = fmt.Errorf("error reading file mode: %v", err) + return + } + size, err := s.ReadInt32() + if err != nil { + err = fmt.Errorf("error reading file size: %v", err) + return + } + mtime, err := s.ReadTime() + if err != nil { + err = fmt.Errorf("error reading file time: %v", err) + return + } + name, err := s.ReadString() + if err != nil { + err = fmt.Errorf("error reading file name: %v", err) + return + } + + done = false + entry = &DirEntry{ + Name: name, + Mode: mode, + Size: size, + ModifiedAt: mtime, + } + return +} diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..b989ba7 --- /dev/null +++ b/doc.go @@ -0,0 +1,12 @@ +/* +Package goadb is a Go interface to the Android Debug Bridge (adb). + +See cmd/demo/demo.go for an example of how to use this library. + +The client/server spec is defined at https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT. + +WARNING This library is under heavy development, and its API is likely to change without notice. +*/ +package goadb + +// TODO(z): Write method-specific examples. diff --git a/host_client.go b/host_client.go index 98b9609..dbecccc 100644 --- a/host_client.go +++ b/host_client.go @@ -1,43 +1,40 @@ -/* -Package goadb is a Go interface to the Android Debug Bridge (adb). - -The client/server spec is defined at https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT. - -WARNING This library is under heavy development, and its API is likely to change without notice. -*/ +// TODO(z): Implement TrackDevices. package goadb import ( - "fmt" "os/exec" "strconv" "github.com/zach-klippenstein/goadb/wire" ) -// Dialer is a function that knows how to create a connection to an adb server. -type Dialer func() (*wire.Conn, error) - /* -HostClient interacts with host services on the adb server. +HostClient communicates with host services on the adb server. Eg. - dialer := &HostClient{wire.Dial} - dialer.GetServerVersion() - -TODO make this a real example. - -TODO Finish implementing services. + client := NewHostClient() + client.StartServer() + client.ListDevices() + client.GetAnyDevice() // see DeviceClient See list of services at https://android.googlesource.com/platform/system/core/+/master/adb/SERVICES.TXT. */ +// TODO(z): Finish implementing host services. type HostClient struct { - Dialer + dialer nilSafeDialer +} + +func NewHostClient() *HostClient { + return NewHostClientDialer(nil) +} + +func NewHostClientDialer(d wire.Dialer) *HostClient { + return &HostClient{nilSafeDialer{d}} } // GetServerVersion asks the ADB server for its internal version number. func (c *HostClient) GetServerVersion() (int, error) { - resp, err := c.roundTripSingleResponse([]byte("host:version")) + resp, err := wire.RoundTripSingleResponse(c.dialer, "host:version") if err != nil { return 0, err } @@ -53,13 +50,13 @@ Corresponds to the command: adb kill-server */ func (c *HostClient) KillServer() error { - conn, err := c.Dialer() + conn, err := c.dialer.Dial() if err != nil { return err } defer conn.Close() - if err = conn.SendMessage([]byte("host:kill")); err != nil { + if err = wire.SendMessageString(conn, "host:kill"); err != nil { return err } @@ -84,7 +81,7 @@ Corresponds to the command: adb devices */ func (c *HostClient) ListDeviceSerials() ([]string, error) { - resp, err := c.roundTripSingleResponse([]byte("host:devices")) + resp, err := wire.RoundTripSingleResponse(c.dialer, "host:devices") if err != nil { return nil, err } @@ -107,8 +104,8 @@ ListDevices returns the list of connected devices. Corresponds to the command: adb devices -l */ -func (c *HostClient) ListDevices() ([]*Device, error) { - resp, err := c.roundTripSingleResponse([]byte("host:devices-l")) +func (c *HostClient) ListDevices() ([]*DeviceInfo, error) { + resp, err := wire.RoundTripSingleResponse(c.dialer, "host:devices-l") if err != nil { return nil, err } @@ -116,41 +113,33 @@ func (c *HostClient) ListDevices() ([]*Device, error) { return parseDeviceList(string(resp), parseDeviceLong) } -func (c *HostClient) roundTripSingleResponse(req []byte) (resp []byte, err error) { - conn, err := c.Dialer() - if err != nil { - return nil, err - } - defer conn.Close() - - if err = conn.SendMessage(req); err != nil { - return nil, err - } - - err = c.readStatusFailureAsError(conn) - if err != nil { - return nil, err - } - - return conn.ReadMessage() +func (c *HostClient) GetDevice(d *DeviceInfo) *DeviceClient { + return c.GetDeviceWithSerial(d.Serial) } -// Reads the status, and if failure, reads the message and returns it as an error. -// If the status is success, doesn't read the message. -func (c *HostClient) readStatusFailureAsError(conn *wire.Conn) error { - status, err := conn.ReadStatus() - if err != nil { - return err - } - - if !status.IsSuccess() { - msg, err := conn.ReadMessage() - if err != nil { - return err - } - - return fmt.Errorf("server error: %s", msg) - } - - return nil +// GetDeviceWithSerial returns a client for the device with the specified serial number. +// Will return a client even if there is no matching device connected. +func (c *HostClient) GetDeviceWithSerial(serial string) *DeviceClient { + return c.getDevice(deviceWithSerial(serial)) +} + +// GetAnyDevice returns a client for any one connected device. +func (c *HostClient) GetAnyDevice() *DeviceClient { + return c.getDevice(anyDevice()) +} + +// GetUsbDevice returns a client for the USB device. +// Will return a client even if there is no device connected. +func (c *HostClient) GetUsbDevice() *DeviceClient { + return c.getDevice(anyUsbDevice()) +} + +// GetLocalDevice returns a client for the local device. +// Will return a client even if there is no device connected. +func (c *HostClient) GetLocalDevice() *DeviceClient { + return c.getDevice(anyLocalDevice()) +} + +func (c *HostClient) getDevice(descriptor *DeviceDescriptor) *DeviceClient { + return &DeviceClient{c.dialer, descriptor} } diff --git a/host_client_test.go b/host_client_test.go index 754f0c2..23a3451 100644 --- a/host_client_test.go +++ b/host_client_test.go @@ -2,6 +2,7 @@ package goadb import ( "io" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -9,29 +10,34 @@ import ( ) func TestGetServerVersion(t *testing.T) { - client := &HostClient{mockDialer(&MockServer{ + s := &MockServer{ Status: wire.StatusSuccess, Messages: []string{"000a"}, - })} + } + client := NewHostClientDialer(s) v, err := client.GetServerVersion() + assert.Equal(t, "host:version", s.Requests[0]) assert.NoError(t, err) assert.Equal(t, 10, v) } -func mockDialer(s *MockServer) Dialer { - return func() (*wire.Conn, error) { - return &wire.Conn{s, s, s}, nil - } -} - type MockServer struct { - Status wire.StatusCode + Status wire.StatusCode + + // Messages are sent in order, each preceded by a length header. Messages []string + // Each request is appended to this slice. + Requests []string + nextMsgIndex int } +func (s *MockServer) Dial() (*wire.Conn, error) { + return wire.NewConn(s, s, s.Close), nil +} + func (s *MockServer) ReadStatus() (wire.StatusCode, error) { return s.Status, nil } @@ -45,7 +51,24 @@ func (s *MockServer) ReadMessage() ([]byte, error) { return []byte(s.Messages[s.nextMsgIndex-1]), nil } +func (s *MockServer) ReadUntilEof() ([]byte, error) { + var data []string + for ; s.nextMsgIndex < len(s.Messages); s.nextMsgIndex++ { + data = append(data, s.Messages[s.nextMsgIndex]) + } + return []byte(strings.Join(data, "")), nil +} + func (s *MockServer) SendMessage(msg []byte) error { + s.Requests = append(s.Requests, string(msg)) + return nil +} + +func (s *MockServer) NewSyncScanner() wire.SyncScanner { + return nil +} + +func (s *MockServer) NewSyncSender() wire.SyncSender { return nil } diff --git a/nil_safe_dialer.go b/nil_safe_dialer.go new file mode 100644 index 0000000..d0e7344 --- /dev/null +++ b/nil_safe_dialer.go @@ -0,0 +1,15 @@ +package goadb + +import "github.com/zach-klippenstein/goadb/wire" + +type nilSafeDialer struct { + wire.Dialer +} + +func (d nilSafeDialer) Dial() (*wire.Conn, error) { + if d.Dialer == nil { + d.Dialer = wire.NewDialer("", 0) + } + + return d.Dialer.Dial() +} diff --git a/sync_client.go b/sync_client.go new file mode 100644 index 0000000..61b063d --- /dev/null +++ b/sync_client.go @@ -0,0 +1,92 @@ +// TODO(z): Implement send. +package goadb + +import ( + "fmt" + "io" + "os" + "time" + + "github.com/zach-klippenstein/goadb/wire" +) + +/* +DirEntry holds information about a directory entry on a device. + +Unfortunately, adb doesn't seem to set the directory bit for directories. +*/ +type DirEntry struct { + Name string + Mode os.FileMode + Size int32 + ModifiedAt time.Time +} + +func stat(conn *wire.SyncConn, path string) (*DirEntry, error) { + if err := conn.SendOctetString("STAT"); err != nil { + return nil, err + } + if err := conn.SendString(path); err != nil { + return nil, err + } + + id, err := conn.ReadOctetString() + if err != nil { + return nil, err + } + if id != "STAT" { + return nil, fmt.Errorf("expected stat ID 'STAT', but got '%s'", id) + } + + return readStat(conn) +} + +func listDirEntries(conn *wire.SyncConn, path string) (entries *DirEntries, err error) { + if err = conn.SendOctetString("LIST"); err != nil { + return + } + if err = conn.SendString(path); err != nil { + return + } + + return &DirEntries{ + scanner: conn, + doneHandler: func() { conn.Close() }, + }, nil +} + +func receiveFile(conn *wire.SyncConn, path string) (io.ReadCloser, error) { + if err := conn.SendOctetString("RECV"); err != nil { + return nil, err + } + if err := conn.SendString(path); err != nil { + return nil, err + } + + return newSyncFileReader(conn), nil +} + +func readStat(s wire.SyncScanner) (entry *DirEntry, err error) { + mode, err := s.ReadFileMode() + if err != nil { + err = fmt.Errorf("error reading file mode: %v", err) + return + } + size, err := s.ReadInt32() + if err != nil { + err = fmt.Errorf("error reading file size: %v", err) + return + } + mtime, err := s.ReadTime() + if err != nil { + err = fmt.Errorf("error reading file time: %v", err) + return + } + + entry = &DirEntry{ + Mode: mode, + Size: size, + ModifiedAt: mtime, + } + return +} diff --git a/sync_client_test.go b/sync_client_test.go new file mode 100644 index 0000000..d1fde95 --- /dev/null +++ b/sync_client_test.go @@ -0,0 +1,2 @@ +// TODO(z): Implement tests for sync_client functions. +package goadb diff --git a/sync_file_reader.go b/sync_file_reader.go new file mode 100644 index 0000000..6ab1c04 --- /dev/null +++ b/sync_file_reader.go @@ -0,0 +1,69 @@ +package goadb + +import ( + "fmt" + "io" + + "github.com/zach-klippenstein/goadb/wire" +) + +// syncFileReader wraps a SyncConn that has requested to receive a file. +type syncFileReader struct { + // Reader used to read data from the adb connection. + scanner wire.SyncScanner + + // Reader for the current chunk only. + chunkReader io.Reader +} + +var _ io.ReadCloser = &syncFileReader{} + +func newSyncFileReader(s wire.SyncScanner) io.ReadCloser { + return &syncFileReader{ + scanner: s, + } +} + +func (r *syncFileReader) Read(buf []byte) (n int, err error) { + if r.chunkReader == nil { + chunkReader, err := readNextChunk(r.scanner) + if err != nil { + // If this is EOF, we've read the last chunk. + // Either way, we want to pass it up to the caller. + return 0, err + } + r.chunkReader = chunkReader + } + + n, err = r.chunkReader.Read(buf) + if err == io.EOF { + // End of current chunk, don't return an error, the next chunk will be + // read on the next call to this method. + r.chunkReader = nil + return n, nil + } + + return n, err +} + +func (r *syncFileReader) Close() error { + return r.scanner.Close() +} + +// readNextChunk creates an io.LimitedReader for the next chunk of data, +// and returns io.EOF if the last chunk has been read. +func readNextChunk(r wire.SyncScanner) (io.Reader, error) { + id, err := r.ReadOctetString() + if err != nil { + return nil, err + } + + switch id { + case "DATA": + return r.ReadBytes() + case "DONE": + return nil, io.EOF + default: + return nil, fmt.Errorf("expected chunk id 'DATA', but got '%s'", id) + } +} diff --git a/sync_file_reader_test.go b/sync_file_reader_test.go new file mode 100644 index 0000000..cd7bfc5 --- /dev/null +++ b/sync_file_reader_test.go @@ -0,0 +1,82 @@ +package goadb + +import ( + "io" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/wire" +) + +func TestReadNextChunk(t *testing.T) { + s := wire.NewSyncScanner(strings.NewReader( + "DATA\006\000\000\000hello DATA\005\000\000\000worldDONE")) + + // Read 1st chunk + reader, err := readNextChunk(s) + assert.NoError(t, err) + assert.Equal(t, 6, reader.(*io.LimitedReader).N) + buf := make([]byte, 10) + n, err := reader.Read(buf) + assert.NoError(t, err) + assert.Equal(t, 6, n) + assert.Equal(t, "hello ", string(buf[:6])) + + // Read 2nd chunk + reader, err = readNextChunk(s) + assert.NoError(t, err) + assert.Equal(t, 5, reader.(*io.LimitedReader).N) + buf = make([]byte, 10) + n, err = reader.Read(buf) + assert.NoError(t, err) + assert.Equal(t, 5, n) + assert.Equal(t, "world", string(buf[:5])) + + // Read DONE + _, err = readNextChunk(s) + assert.Equal(t, io.EOF, err) +} +func TestReadNextChunkInvalidChunkId(t *testing.T) { + s := wire.NewSyncScanner(strings.NewReader( + "ATAD\006\000\000\000hello ")) + + // Read 1st chunk + _, err := readNextChunk(s) + assert.EqualError(t, err, "expected chunk id 'DATA', but got 'ATAD'") +} + +func TestReadMultipleCalls(t *testing.T) { + s := wire.NewSyncScanner(strings.NewReader( + "DATA\006\000\000\000hello DATA\005\000\000\000worldDONE")) + reader := newSyncFileReader(s) + + firstByte := make([]byte, 1) + _, err := io.ReadFull(reader, firstByte) + assert.NoError(t, err) + assert.Equal(t, "h", string(firstByte)) + + restFirstChunkBytes := make([]byte, 5) + _, err = io.ReadFull(reader, restFirstChunkBytes) + assert.NoError(t, err) + assert.Equal(t, "ello ", string(restFirstChunkBytes)) + + secondChunkBytes := make([]byte, 5) + _, err = io.ReadFull(reader, secondChunkBytes) + assert.NoError(t, err) + assert.Equal(t, "world", string(secondChunkBytes)) + + _, err = io.ReadFull(reader, make([]byte, 5)) + assert.Equal(t, io.EOF, err) +} + +func TestReadAll(t *testing.T) { + s := wire.NewSyncScanner(strings.NewReader( + "DATA\006\000\000\000hello DATA\005\000\000\000worldDONE")) + reader := newSyncFileReader(s) + + buf := make([]byte, 20) + _, err := io.ReadFull(reader, buf) + assert.Equal(t, io.ErrUnexpectedEOF, err) + assert.Equal(t, "hello world\000", string(buf[:12])) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..e6f7c19 --- /dev/null +++ b/util.go @@ -0,0 +1,18 @@ +package goadb + +import "strings" +import ( + "regexp" +) + +var ( + whitespaceRegex = regexp.MustCompile(`^\s*$`) +) + +func containsWhitespace(str string) bool { + return strings.ContainsAny(str, " \t\v") +} + +func isBlank(str string) bool { + return whitespaceRegex.MatchString(str) +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..66aab99 --- /dev/null +++ b/util_test.go @@ -0,0 +1,27 @@ +package goadb + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestContainsWhitespaceYes(t *testing.T) { + assert.True(t, containsWhitespace("hello world")) +} + +func TestContainsWhitespaceNo(t *testing.T) { + assert.False(t, containsWhitespace("hello")) +} + +func TestIsBlankWhenEmpty(t *testing.T) { + assert.True(t, isBlank("")) +} + +func TestIsBlankWhenJustWhitespace(t *testing.T) { + assert.True(t, isBlank(" \t")) +} + +func TestIsBlankNo(t *testing.T) { + assert.False(t, isBlank(" h ")) +} diff --git a/wire/adb_error.go b/wire/adb_error.go new file mode 100644 index 0000000..a20b871 --- /dev/null +++ b/wire/adb_error.go @@ -0,0 +1,24 @@ +package wire + +import ( + "fmt" +) + +type AdbError struct { + Request []byte + ServerMsg string +} + +var _ error = &AdbError{} + +func (e *AdbError) Error() string { + if e.Request == nil { + return fmt.Sprintf("server error: %s", e.ServerMsg) + } else { + return fmt.Sprintf("server error for request '%s': %s", e.Request, e.ServerMsg) + } +} + +func incompleteMessage(description string, actual int, expected int) error { + return fmt.Errorf("incomplete %s: read %d bytes, expecting %d", description, actual, expected) +} diff --git a/wire/conn.go b/wire/conn.go index 7a8baa2..be49829 100644 --- a/wire/conn.go +++ b/wire/conn.go @@ -1,13 +1,18 @@ -/* -The wire package implements the low-level part of the client/server wire protocol. +package wire -The protocol spec can be found at -https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT. +const ( + // The official implementation of adb imposes an undocumented 255-byte limit + // on messages. + MaxMessageLength = 255 +) + +/* +Conn is a normal connection to an adb server. For most cases, usage looks something like: conn := wire.Dial() conn.SendMessage(data) - conn.ReadStatus() == "OKAY" || "FAIL" + conn.ReadStatus() == StatusSuccess || StatusFailure conn.ReadMessage() conn.Close() @@ -18,52 +23,44 @@ it returns an io.EOF error. For most commands, the server will close the connection after sending the response. You should still always call Close() when you're done with the connection. */ -package wire - -import ( - "fmt" - "io" - "net" -) - -const ( - // Default port the adb server listens on. - AdbPort = 5037 - - // The official implementation of adb imposes an undocumented 255-byte limit - // on messages. - MaxMessageLength = 255 -) - -// Conn is a connection to an adb server. type Conn struct { Scanner Sender - io.Closer + closer func() error } -// Dial connects to the adb server on the default port, AdbPort. -func Dial() (*Conn, error) { - return DialPort(AdbPort) +func NewConn(scanner Scanner, sender Sender, closer func() error) *Conn { + return &Conn{scanner, sender, closer} } -// Dial connects to the adb server on port. -func DialPort(port int) (*Conn, error) { - return DialAddr(fmt.Sprintf("localhost:%d", port)) +// Close closes the underlying connection. +func (c *Conn) Close() error { + if c.closer != nil { + return c.closer() + } + return nil } -// Dial connects to the adb server at address. -func DialAddr(address string) (*Conn, error) { - netConn, err := net.Dial("tcp", address) - if err != nil { +// NewSyncConn returns connection that can operate in sync mode. +// The connection must already have been switched (by sending the sync command +// to a specific device), or the return connection will return an error. +func (c *Conn) NewSyncConn() *SyncConn { + return &SyncConn{ + SyncScanner: c.Scanner.NewSyncScanner(), + SyncSender: c.Sender.NewSyncSender(), + } +} + +// RoundTripSingleResponse sends a message to the server, and reads a single +// message response. If the reponse has a failure status code, returns it as an error. +func (conn *Conn) RoundTripSingleResponse(req []byte) (resp []byte, err error) { + if err = conn.SendMessage(req); err != nil { return nil, err } - return &Conn{ - Scanner: NewScanner(netConn), - Sender: NewSender(netConn), - Closer: netConn, - }, nil -} + if err = ReadStatusFailureAsError(conn, req); err != nil { + return nil, err + } -var _ io.Closer = &Conn{} + return conn.ReadMessage() +} diff --git a/wire/dialer.go b/wire/dialer.go new file mode 100644 index 0000000..ccf883d --- /dev/null +++ b/wire/dialer.go @@ -0,0 +1,70 @@ +package wire + +import ( + "fmt" + "net" + "runtime" +) + +const ( + // Default port the adb server listens on. + AdbPort = 5037 +) + +/* +Dialer knows how to create connections to an adb server. +*/ +type Dialer interface { + Dial() (*Conn, error) +} + +type netDialer struct { + Host string + Port int +} + +func NewDialer(host string, port int) Dialer { + return &netDialer{host, port} +} + +// Dial connects to the adb server on the host and port set on the netDialer. +// The zero-value will connect to the default, localhost:5037. +func (d *netDialer) Dial() (*Conn, error) { + host := d.Host + if host == "" { + host = "localhost" + } + + port := d.Port + if port == 0 { + port = AdbPort + } + + netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) + if err != nil { + return nil, err + } + + conn := &Conn{ + Scanner: NewScanner(netConn), + Sender: NewSender(netConn), + closer: netConn.Close, + } + + // Prevent leaking the network connection, not sure if TCPConn does this itself. + runtime.SetFinalizer(netConn, func(conn *net.TCPConn) { + conn.Close() + }) + + return conn, nil +} + +func RoundTripSingleResponse(d Dialer, req string) ([]byte, error) { + conn, err := d.Dial() + if err != nil { + return nil, err + } + defer conn.Close() + + return conn.RoundTripSingleResponse([]byte(req)) +} diff --git a/wire/doc.go b/wire/doc.go new file mode 100644 index 0000000..a4293a1 --- /dev/null +++ b/wire/doc.go @@ -0,0 +1,13 @@ +/* +Package wire implements the low-level part of the client/server wire protocol. +It also implements the "sync" wire format for file transfers. + +This package is not intended to be used directly. goadb.HostClient and goadb.DeviceClient +use it to abstract away the bit-twiddling details of the protocol. You should only ever +need to work with the goadb package. Also, this package's API may change more frequently +than goadb's. + +The protocol spec can be found at +https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT. +*/ +package wire diff --git a/wire/scanner.go b/wire/scanner.go index 6902cf1..e8ffa32 100644 --- a/wire/scanner.go +++ b/wire/scanner.go @@ -3,6 +3,7 @@ package wire import ( "fmt" "io" + "io/ioutil" "strconv" ) @@ -27,6 +28,8 @@ See Conn for more details. type Scanner interface { ReadStatus() (StatusCode, error) ReadMessage() ([]byte, error) + ReadUntilEof() ([]byte, error) + NewSyncScanner() SyncScanner } type realScanner struct { @@ -73,6 +76,38 @@ func (s *realScanner) ReadMessage() ([]byte, error) { return data, nil } +func (s *realScanner) ReadUntilEof() ([]byte, error) { + return ioutil.ReadAll(s.reader) +} + +func (s *realScanner) NewSyncScanner() SyncScanner { + return NewSyncScanner(s.reader) +} + +// Reads the status, and if failure, reads the message and returns it as an error. +// If the status is success, doesn't read the message. +// req is just used to populate the AdbError, and can be nil. +func ReadStatusFailureAsError(s Scanner, req []byte) error { + status, err := s.ReadStatus() + if err != nil { + return err + } + + if !status.IsSuccess() { + msg, err := s.ReadMessage() + if err != nil { + return err + } + + return &AdbError{ + Request: req, + ServerMsg: string(msg), + } + } + + return nil +} + func (s *realScanner) readLength() (int, error) { lengthHex := make([]byte, 4) n, err := io.ReadFull(s.reader, lengthHex) @@ -95,8 +130,4 @@ func (s *realScanner) readLength() (int, error) { return int(length), nil } -func incompleteMessage(description string, actual int, expected int) error { - return fmt.Errorf("incomplete %s: read %d bytes, expecting %d", description, actual, expected) -} - var _ Scanner = &realScanner{} diff --git a/wire/sender.go b/wire/sender.go index 8ee3c4e..79444d6 100644 --- a/wire/sender.go +++ b/wire/sender.go @@ -8,6 +8,7 @@ import ( // Sender sends messages to the server. type Sender interface { SendMessage(msg []byte) error + NewSyncSender() SyncSender } type realSender struct { @@ -31,15 +32,8 @@ func (s *realSender) SendMessage(msg []byte) error { return writeFully(s.writer, []byte(lengthAndMsg)) } -func writeFully(w io.Writer, data []byte) error { - for len(data) > 0 { - n, err := w.Write(data) - if err != nil { - return err - } - data = data[n:] - } - return nil +func (s *realSender) NewSyncSender() SyncSender { + return NewSyncSender(s.writer) } var _ Sender = &realSender{} diff --git a/wire/sync.go b/wire/sync.go new file mode 100644 index 0000000..d69f63f --- /dev/null +++ b/wire/sync.go @@ -0,0 +1,198 @@ +// TODO(z): Write SyncSender.SendBytes(). +package wire + +import ( + "encoding/binary" + "fmt" + "io" + "os" + "time" +) + +const ( + // Chunks cannot be longer than 64k. + MaxChunkSize = 64 * 1024 +) + +/* +SyncConn is a connection to the adb server in sync mode. +Assumes the connection has been put into sync mode (by sending "sync" in transport mode). + +The adb sync protocol is defined at +https://android.googlesource.com/platform/system/core/+/master/adb/SYNC.TXT. + +Unlike the normal adb protocol (implemented in Conn), the sync protocol is binary. +Lengths are binary-encoded (little-endian) instead of hex. + +Notes on Encoding + +Length headers and other integers are encoded in little-endian, with 32 bits. + +File mode seems to be encoded as POSIX file mode. + +Modification time seems to be the Unix timestamp format, i.e. seconds since Epoch UTC. +*/ +type SyncConn struct { + SyncScanner + SyncSender +} + +func (c *SyncConn) Close() error { + return c.SyncScanner.Close() +} + +type SyncScanner interface { + // ReadOctetString reads a 4-byte string. + ReadOctetString() (string, error) + ReadInt32() (int32, error) + ReadFileMode() (os.FileMode, error) + ReadTime() (time.Time, error) + + // Reads an octet length, followed by length bytes. + ReadString() (string, error) + + // Reads an octet length, and returns a reader that will read length + // bytes (see io.LimitReader). The returned reader should be fully + // read before reading anything off the Scanner again. + ReadBytes() (io.Reader, error) + + // Closes the underlying reader. + Close() error +} + +type SyncSender interface { + // SendOctetString sends a 4-byte string. + SendOctetString(string) error + SendInt32(int32) error + SendFileMode(os.FileMode) error + SendTime(time.Time) error + + // Sends len(bytes) as an octet, followed by bytes. + SendString(str string) error +} + +type realSyncScanner struct { + io.Reader +} + +type realSyncSender struct { + io.Writer +} + +func NewSyncScanner(r io.Reader) SyncScanner { + return &realSyncScanner{r} +} + +func NewSyncSender(w io.Writer) SyncSender { + return &realSyncSender{w} +} + +func RequireOctetString(s SyncScanner, expected string) error { + actual, err := s.ReadOctetString() + if err != nil { + return fmt.Errorf("expected to read '%s', got err: %v", expected, err) + } + if actual != expected { + return fmt.Errorf("expected to read '%s', got '%s'", expected, actual) + } + return nil +} + +func (s *realSyncScanner) ReadOctetString() (string, error) { + octet := make([]byte, 4) + n, err := io.ReadFull(s.Reader, octet) + if err != nil && err != io.ErrUnexpectedEOF { + return "", err + } else if err == io.ErrUnexpectedEOF { + return "", incompleteMessage("octet", n, 4) + } + + return string(octet), nil +} + +func (s *realSyncSender) SendOctetString(str string) error { + if len(str) != 4 { + return fmt.Errorf("octet string must be exactly 4 bytes: '%s'", str) + } + return writeFully(s.Writer, []byte(str)) +} + +func (s *realSyncScanner) ReadInt32() (int32, error) { + var value int32 + err := binary.Read(s.Reader, binary.LittleEndian, &value) + return value, err +} + +func (s *realSyncSender) SendInt32(val int32) error { + return binary.Write(s.Writer, binary.LittleEndian, val) +} + +func (s *realSyncScanner) ReadFileMode() (os.FileMode, error) { + var value uint32 + err := binary.Read(s.Reader, binary.LittleEndian, &value) + return os.FileMode(value), err +} + +func (s *realSyncSender) SendFileMode(mode os.FileMode) error { + return binary.Write(s.Writer, binary.LittleEndian, mode) +} + +func (s *realSyncScanner) ReadTime() (time.Time, error) { + seconds, err := s.ReadInt32() + if err != nil { + return time.Time{}, err + } + + return time.Unix(int64(seconds), 0).UTC(), nil +} + +func (s *realSyncSender) SendTime(t time.Time) error { + return s.SendInt32(int32(t.Unix())) +} + +func (s *realSyncScanner) ReadString() (string, error) { + length, err := s.ReadInt32() + if err != nil { + return "", err + } + + bytes := make([]byte, length) + n, err := io.ReadFull(s.Reader, bytes) + if err != nil && err != io.ErrUnexpectedEOF { + return "", err + } else if err == io.ErrUnexpectedEOF { + return "", incompleteMessage("bytes", n, int(length)) + } + + return string(bytes), nil +} + +func (s *realSyncSender) SendString(str string) error { + length := len(str) + if length > MaxChunkSize { + // This limit might not apply to filenames, but it's big enough + // that I don't think it will be a problem. + return fmt.Errorf("str must be <= %d in length", MaxChunkSize) + } + + if err := s.SendInt32(int32(length)); err != nil { + return err + } + return writeFully(s.Writer, []byte(str)) +} + +func (s *realSyncScanner) ReadBytes() (io.Reader, error) { + length, err := s.ReadInt32() + if err != nil { + return nil, err + } + + return io.LimitReader(s.Reader, int64(length)), nil +} + +func (s *realSyncScanner) Close() error { + if closer, ok := s.Reader.(io.Closer); ok { + return closer.Close() + } + return nil +} diff --git a/wire/sync_test.go b/wire/sync_test.go new file mode 100644 index 0000000..d11a522 --- /dev/null +++ b/wire/sync_test.go @@ -0,0 +1,87 @@ +package wire + +import ( + "bytes" + "io/ioutil" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +var ( + someTime = time.Date(2015, 04, 12, 20, 7, 51, 0, time.UTC) + // The little-endian encoding of someTime.Unix() + someTimeEncoded = []byte{151, 208, 42, 85} +) + +func TestSyncReadOctetString(t *testing.T) { + s := NewSyncScanner(strings.NewReader("helo")) + str, err := s.ReadOctetString() + assert.NoError(t, err) + assert.Equal(t, "helo", str) +} + +func TestSyncSendOctetString(t *testing.T) { + var buf bytes.Buffer + s := NewSyncSender(&buf) + err := s.SendOctetString("helo") + assert.NoError(t, err) + assert.Equal(t, "helo", buf.String()) +} + +func TestSyncSendOctetStringTooLong(t *testing.T) { + var buf bytes.Buffer + s := NewSyncSender(&buf) + err := s.SendOctetString("hello") + assert.EqualError(t, err, "octet string must be exactly 4 bytes: 'hello'") +} + +func TestSyncReadTime(t *testing.T) { + s := NewSyncScanner(bytes.NewReader(someTimeEncoded)) + decoded, err := s.ReadTime() + assert.NoError(t, err) + assert.Equal(t, someTime, decoded) +} + +func TestSyncSendTime(t *testing.T) { + var buf bytes.Buffer + s := NewSyncSender(&buf) + err := s.SendTime(someTime) + assert.NoError(t, err) + assert.Equal(t, someTimeEncoded, buf.Bytes()) +} + +func TestSyncReadString(t *testing.T) { + s := NewSyncScanner(strings.NewReader("\005\000\000\000hello")) + str, err := s.ReadString() + assert.NoError(t, err) + assert.Equal(t, "hello", str) +} + +func TestSyncReadStringTooShort(t *testing.T) { + s := NewSyncScanner(strings.NewReader("\005\000\000\000h")) + _, err := s.ReadString() + assert.EqualError(t, err, "incomplete bytes: read 1 bytes, expecting 5") +} + +func TestSyncSendString(t *testing.T) { + var buf bytes.Buffer + s := NewSyncSender(&buf) + err := s.SendString("hello") + assert.NoError(t, err) + assert.Equal(t, "\005\000\000\000hello", buf.String()) +} + +func TestSyncReadBytes(t *testing.T) { + s := NewSyncScanner(strings.NewReader("\005\000\000\000helloworld")) + + reader, err := s.ReadBytes() + assert.NoError(t, err) + assert.NotNil(t, reader) + + str, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + assert.Equal(t, "hello", string(str)) +} diff --git a/wire/util.go b/wire/util.go new file mode 100644 index 0000000..15fd32f --- /dev/null +++ b/wire/util.go @@ -0,0 +1,16 @@ +package wire + +import "io" + +// writeFully writes all of data to w. +// Inverse of io.ReadFully(). +func writeFully(w io.Writer, data []byte) error { + for len(data) > 0 { + n, err := w.Write(data) + if err != nil { + return err + } + data = data[n:] + } + return nil +}