commit 8d46dec61dad57b2e5691a73f2726d04366a4f79 Author: Zach Klippenstein Date: Sat Apr 11 14:45:29 2015 -0700 Initial commit – wire format implementation and some host client commands. diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..9c8f3ea --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..0041124 --- /dev/null +++ b/README.md @@ -0,0 +1,5 @@ +#goadb + +[![GoDoc](https://godoc.org/github.com/zach-klippenstein/goadb?status.svg)](https://godoc.org/github.com/zach-klippenstein/goadb) + +A Golang library for interacting with the Android Debug Bridge (adb). \ No newline at end of file diff --git a/cmd/demo/demo.go b/cmd/demo/demo.go new file mode 100644 index 0000000..cfd6647 --- /dev/null +++ b/cmd/demo/demo.go @@ -0,0 +1,33 @@ +package main + +import ( + "fmt" + "log" + + adb "github.com/zach-klippenstein/goadb" + "github.com/zach-klippenstein/goadb/wire" +) + +func main() { + client := &adb.HostClient{wire.Dial} + fmt.Println("Starting server…") + client.StartServer() + + serverVersion, err := client.GetServerVersion() + if err != nil { + log.Fatal(err) + } + fmt.Println("Server version:", serverVersion) + + devices, err := client.ListDevices() + if err != nil { + log.Fatal(err) + } + fmt.Println("Devices:") + for _, device := range devices { + fmt.Printf("\t%+v\n", *device) + } + + fmt.Println("Killing server…") + client.KillServer() +} diff --git a/cmd/raw-adb/raw-adb.go b/cmd/raw-adb/raw-adb.go new file mode 100644 index 0000000..e9a06af --- /dev/null +++ b/cmd/raw-adb/raw-adb.go @@ -0,0 +1,78 @@ +package main + +import ( + "bufio" + "flag" + "fmt" + "io" + "log" + "os" + "strings" + + "github.com/zach-klippenstein/goadb/wire" +) + +var port = flag.Int("p", wire.AdbPort, "port the adb server is listening on") + +func main() { + flag.Parse() + + fmt.Println("using port", *port) + + printServerVersion() + + for { + line := readLine() + err := doCommand(line) + if err != nil { + fmt.Println("error:", err) + } + } +} + +func printServerVersion() { + err := doCommand("host:version") + if err != nil { + log.Fatal(err) + } +} + +func readLine() string { + fmt.Print("> ") + line, err := bufio.NewReader(os.Stdin).ReadString('\n') + if err != nil && err != io.EOF { + log.Fatal(err) + } + return strings.TrimSpace(line) +} + +func doCommand(cmd string) error { + conn, err := wire.DialPort(*port) + if err != nil { + log.Fatal(err) + } + defer conn.Close() + + if err := wire.SendMessageString(conn, cmd); err != nil { + return err + } + + status, err := conn.ReadStatus() + if err != nil { + return err + } + + var msg string + for err == nil { + msg, err = wire.ReadMessageString(conn) + if err == nil { + fmt.Printf("%s> %s\n", status, msg) + } + } + + if err != io.EOF { + return err + } + + return nil +} diff --git a/devices.go b/devices.go new file mode 100644 index 0000000..6521c57 --- /dev/null +++ b/devices.go @@ -0,0 +1,92 @@ +package goadb + +import ( + "bufio" + "fmt" + "strings" +) + +// Device represents a connected Android device. +type Device struct { + // Always set. + Serial string + + // Product, device, and model are not set in the short form. + Product string + Model string + Device string + + // Only set for devices connected via USB. + Usb string +} + +// IsUsb returns true if the device is connected via USB. +func (d *Device) IsUsb() bool { + return d.Usb != "" +} + +func newDevice(serial string, attrs map[string]string) (*Device, error) { + if serial == "" { + return nil, fmt.Errorf("device serial cannot be blank") + } + + return &Device{ + Serial: serial, + Product: attrs["product"], + Model: attrs["model"], + Device: attrs["device"], + Usb: attrs["usb"], + }, nil +} + +func parseDeviceList(list string, lineParseFunc func(string) (*Device, error)) ([]*Device, error) { + var devices []*Device + scanner := bufio.NewScanner(strings.NewReader(list)) + + for scanner.Scan() { + device, err := lineParseFunc(scanner.Text()) + if err != nil { + return nil, err + } + devices = append(devices, device) + } + if err := scanner.Err(); err != nil { + return nil, err + } + + return devices, nil +} + +func parseDeviceShort(line string) (*Device, error) { + fields := strings.Fields(line) + if len(fields) != 2 { + return nil, fmt.Errorf("malformed device line, expected 2 fields but found %d", len(fields)) + } + + return newDevice(fields[0], map[string]string{}) +} + +func parseDeviceLong(line string) (*Device, error) { + fields := strings.Fields(line) + if len(fields) < 5 { + return nil, fmt.Errorf("malformed device line, expected at least 5 fields but found %d", len(fields)) + } + + attrs := parseDeviceAttributes(fields[2:]) + return newDevice(fields[0], attrs) +} + +func parseDeviceAttributes(fields []string) map[string]string { + attrs := map[string]string{} + for _, field := range fields { + key, val := parseKeyVal(field) + attrs[key] = val + } + return attrs +} + +// Parses a key:val pair and returns key, val. +func parseKeyVal(pair string) (string, string) { + split := strings.Split(pair, ":") + return split[0], split[1] +} diff --git a/devices_test.go b/devices_test.go new file mode 100644 index 0000000..2920efe --- /dev/null +++ b/devices_test.go @@ -0,0 +1,45 @@ +package goadb + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func ParseDeviceList(t *testing.T) { + devs, err := parseDeviceList(`192.168.56.101:5555 device +05856558`, parseDeviceShort) + + assert.NoError(t, err) + assert.Len(t, devs, 2) + assert.Equal(t, "192.168.56.101:5555", devs[0].Serial) + assert.Equal(t, "05856558", devs[1].Serial) +} + +func TestParseDeviceShort(t *testing.T) { + dev, err := parseDeviceShort("192.168.56.101:5555 device\n") + assert.NoError(t, err) + assert.Equal(t, &Device{ + Serial: "192.168.56.101:5555"}, dev) +} + +func TestParseDeviceLong(t *testing.T) { + dev, err := parseDeviceLong("SERIAL device product:PRODUCT model:MODEL device:DEVICE\n") + assert.NoError(t, err) + assert.Equal(t, &Device{ + Serial: "SERIAL", + Product: "PRODUCT", + Model: "MODEL", + Device: "DEVICE"}, dev) +} + +func TestParseDeviceLongUsb(t *testing.T) { + dev, err := parseDeviceLong("SERIAL device usb:1234 product:PRODUCT model:MODEL device:DEVICE \n") + assert.NoError(t, err) + assert.Equal(t, &Device{ + Serial: "SERIAL", + Product: "PRODUCT", + Model: "MODEL", + Device: "DEVICE", + Usb: "1234"}, dev) +} diff --git a/host_client.go b/host_client.go new file mode 100644 index 0000000..4d27fdc --- /dev/null +++ b/host_client.go @@ -0,0 +1,156 @@ +/* +Go interface to the Android Debug Bridge (adb). + +The client/server spec is defined at https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT. + +WARNING This library is under heavy development, and its API is likely to change without notice. +*/ +package goadb + +import ( + "fmt" + "os/exec" + "strconv" + + "github.com/zach-klippenstein/goadb/wire" +) + +// Dialer is a function that knows how to create a connection to an adb server. +type Dialer func() (*wire.Conn, error) + +/* +Interacts with host services on the adb server. + +Eg. + dialer := &HostClient{wire.Dial} + dialer.GetServerVersion() + +TODO make this a real example. + +TODO Finish implementing services. + +See list of services at https://android.googlesource.com/platform/system/core/+/master/adb/SERVICES.TXT. +*/ +type HostClient struct { + Dialer +} + +// GetServerVersion asks the ADB server for its internal version number. +func (c *HostClient) GetServerVersion() (int, error) { + resp, err := c.roundTripSingleResponse([]byte("host:version")) + if err != nil { + return 0, err + } + + version, err := strconv.ParseInt(string(resp), 16, 32) + return int(version), err +} + +/* +KillServer tells the server to quit immediately. + +Corresponds to the command: + adb kill-server +*/ +func (c *HostClient) KillServer() error { + conn, err := c.Dialer() + if err != nil { + return err + } + defer conn.Close() + + if err = conn.SendMessage([]byte("host:kill")); err != nil { + return err + } + + return nil +} + +/* +StartServer ensures there is a server running. + +Currently implemented by just running + adb start-server +*/ +func (c *HostClient) StartServer() error { + cmd := exec.Command("adb", "start-server") + return cmd.Run() +} + +/* +ListDeviceSerials returns the serial numbers of all attached devices. + +Corresponds to the command: + adb devices +*/ +func (c *HostClient) ListDeviceSerials() ([]string, error) { + resp, err := c.roundTripSingleResponse([]byte("host:devices")) + if err != nil { + return nil, err + } + + devices, err := parseDeviceList(string(resp), parseDeviceShort) + if err != nil { + return nil, err + } + + serials := make([]string, len(devices)) + for i, dev := range devices { + serials[i] = dev.Serial + } + return serials, nil +} + +/* +ListDevices returns the list of connected devices. + +Corresponds to the command: + adb devices -l +*/ +func (c *HostClient) ListDevices() ([]*Device, error) { + resp, err := c.roundTripSingleResponse([]byte("host:devices-l")) + if err != nil { + return nil, err + } + + return parseDeviceList(string(resp), parseDeviceLong) +} + +func (c *HostClient) roundTripSingleResponse(req []byte) (resp []byte, err error) { + conn, err := c.Dialer() + if err != nil { + return nil, err + } + defer conn.Close() + + if err = conn.SendMessage(req); err != nil { + return nil, err + } + + err = c.readStatusFailureAsError(conn) + if err != nil { + return nil, err + } + + return conn.ReadMessage() +} + +// Reads the status, and if failure, reads the message and returns it as an error. +// If the status is success, doesn't read the message. +func (c *HostClient) readStatusFailureAsError(conn *wire.Conn) error { + status, err := conn.ReadStatus() + if err != nil { + return err + } + + if !status.IsSuccess() { + msg, err := conn.ReadMessage() + if err != nil { + return err + } + + return fmt.Errorf("server error: %s", msg) + } + + return nil +} diff --git a/host_client_test.go b/host_client_test.go new file mode 100644 index 0000000..754f0c2 --- /dev/null +++ b/host_client_test.go @@ -0,0 +1,54 @@ +package goadb + +import ( + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/zach-klippenstein/goadb/wire" +) + +func TestGetServerVersion(t *testing.T) { + client := &HostClient{mockDialer(&MockServer{ + Status: wire.StatusSuccess, + Messages: []string{"000a"}, + })} + + v, err := client.GetServerVersion() + assert.NoError(t, err) + assert.Equal(t, 10, v) +} + +func mockDialer(s *MockServer) Dialer { + return func() (*wire.Conn, error) { + return &wire.Conn{s, s, s}, nil + } +} + +type MockServer struct { + Status wire.StatusCode + Messages []string + + nextMsgIndex int +} + +func (s *MockServer) ReadStatus() (wire.StatusCode, error) { + return s.Status, nil +} + +func (s *MockServer) ReadMessage() ([]byte, error) { + if s.nextMsgIndex >= len(s.Messages) { + return nil, io.EOF + } + + s.nextMsgIndex++ + return []byte(s.Messages[s.nextMsgIndex-1]), nil +} + +func (s *MockServer) SendMessage(msg []byte) error { + return nil +} + +func (s *MockServer) Close() error { + return nil +} diff --git a/wire/conn.go b/wire/conn.go new file mode 100644 index 0000000..7a8baa2 --- /dev/null +++ b/wire/conn.go @@ -0,0 +1,69 @@ +/* +The wire package implements the low-level part of the client/server wire protocol. + +The protocol spec can be found at +https://android.googlesource.com/platform/system/core/+/master/adb/OVERVIEW.TXT. + +For most cases, usage looks something like: + conn := wire.Dial() + conn.SendMessage(data) + conn.ReadStatus() == "OKAY" || "FAIL" + conn.ReadMessage() + conn.Close() + +For some messages, the server will return more than one message (but still a single +status). Generally, after calling ReadStatus once, you should call ReadMessage until +it returns an io.EOF error. + +For most commands, the server will close the connection after sending the response. +You should still always call Close() when you're done with the connection. +*/ +package wire + +import ( + "fmt" + "io" + "net" +) + +const ( + // Default port the adb server listens on. + AdbPort = 5037 + + // The official implementation of adb imposes an undocumented 255-byte limit + // on messages. + MaxMessageLength = 255 +) + +// Conn is a connection to an adb server. +type Conn struct { + Scanner + Sender + io.Closer +} + +// Dial connects to the adb server on the default port, AdbPort. +func Dial() (*Conn, error) { + return DialPort(AdbPort) +} + +// Dial connects to the adb server on port. +func DialPort(port int) (*Conn, error) { + return DialAddr(fmt.Sprintf("localhost:%d", port)) +} + +// Dial connects to the adb server at address. +func DialAddr(address string) (*Conn, error) { + netConn, err := net.Dial("tcp", address) + if err != nil { + return nil, err + } + + return &Conn{ + Scanner: NewScanner(netConn), + Sender: NewSender(netConn), + Closer: netConn, + }, nil +} + +var _ io.Closer = &Conn{} diff --git a/wire/scanner.go b/wire/scanner.go new file mode 100644 index 0000000..6902cf1 --- /dev/null +++ b/wire/scanner.go @@ -0,0 +1,102 @@ +package wire + +import ( + "fmt" + "io" + "strconv" +) + +// StatusCodes are returned by the server. If the code indicates failure, the +// next message will be the error. +type StatusCode string + +const ( + StatusSuccess StatusCode = "OKAY" + StatusFailure = "FAIL" + StatusNone = "" +) + +func (status StatusCode) IsSuccess() bool { + return status == StatusSuccess +} + +/* +Scanner reads tokens from a server. +See Conn for more details. +*/ +type Scanner interface { + ReadStatus() (StatusCode, error) + ReadMessage() ([]byte, error) +} + +type realScanner struct { + reader io.Reader +} + +func NewScanner(r io.Reader) Scanner { + return &realScanner{r} +} + +func ReadMessageString(s Scanner) (string, error) { + msg, err := s.ReadMessage() + if err != nil { + return string(msg), err + } + return string(msg), nil +} + +func (s *realScanner) ReadStatus() (StatusCode, error) { + status := make([]byte, 4) + n, err := io.ReadFull(s.reader, status) + if err != nil && err != io.ErrUnexpectedEOF { + return "", err + } else if err == io.ErrUnexpectedEOF { + return StatusCode(status), incompleteMessage("status", n, 4) + } + + return StatusCode(status), nil +} + +func (s *realScanner) ReadMessage() ([]byte, error) { + length, err := s.readLength() + if err != nil { + return nil, err + } + + data := make([]byte, length) + n, err := io.ReadFull(s.reader, data) + if err != nil && err != io.ErrUnexpectedEOF { + return data, fmt.Errorf("error reading message data: %v", err) + } else if err == io.ErrUnexpectedEOF { + return data, incompleteMessage("message data", n, length) + } + return data, nil +} + +func (s *realScanner) readLength() (int, error) { + lengthHex := make([]byte, 4) + n, err := io.ReadFull(s.reader, lengthHex) + if err != nil && err != io.ErrUnexpectedEOF { + return 0, err + } else if err == io.ErrUnexpectedEOF { + return 0, incompleteMessage("length", n, 4) + } + + length, err := strconv.ParseInt(string(lengthHex), 16, 64) + if err != nil { + return 0, fmt.Errorf("invalid hex length: %v", err) + } + + // Clip the length to 255, as per the Google implementation. + if length > MaxMessageLength { + length = MaxMessageLength + } + + return int(length), nil +} + +func incompleteMessage(description string, actual int, expected int) error { + return fmt.Errorf("incomplete %s: read %d bytes, expecting %d", description, actual, expected) +} + +var _ Scanner = &realScanner{} diff --git a/wire/scanner_test.go b/wire/scanner_test.go new file mode 100644 index 0000000..cf958e2 --- /dev/null +++ b/wire/scanner_test.go @@ -0,0 +1,106 @@ +package wire + +import ( + "bufio" + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestReadStatusOkay(t *testing.T) { + s := NewScannerString("OKAYd") + status, err := s.ReadStatus() + assert.NoError(t, err) + assert.True(t, status.IsSuccess()) + assertNotEof(t, s) +} + +func TestReadIncompleteStatus(t *testing.T) { + s := NewScannerString("oka") + _, err := s.ReadStatus() + assert.Equal(t, incompleteMessage("status", 3, 4), err) + assertEof(t, s) +} + +func TestReadLength(t *testing.T) { + s := NewScannerString("000a") + l, err := s.readLength() + assert.NoError(t, err) + assert.Equal(t, 10, l) + assertEof(t, s) +} + +func TestReadIncompleteLength(t *testing.T) { + s := NewScannerString("aaa") + _, err := s.readLength() + assert.Equal(t, incompleteMessage("length", 3, 4), err) + assertEof(t, s) +} + +func TestReadMessage(t *testing.T) { + s := NewScannerString("0005hello") + msg, err := ReadMessageString(s) + assert.NoError(t, err) + assert.Len(t, msg, 5) + assert.Equal(t, "hello", msg) + assertEof(t, s) +} + +func TestReadMessageWithExtraData(t *testing.T) { + s := NewScannerString("0005hellothere") + msg, err := ReadMessageString(s) + assert.NoError(t, err) + assert.Len(t, msg, 5) + assert.Equal(t, "hello", msg) + assertNotEof(t, s) +} + +func TestReadLongerMessage(t *testing.T) { + s := NewScannerString("001b192.168.56.101:5555 device\n") + msg, err := ReadMessageString(s) + assert.NoError(t, err) + assert.Len(t, msg, 27) + assert.Equal(t, "192.168.56.101:5555 device\n", msg) + assertEof(t, s) +} + +func TestReadEmptyMessage(t *testing.T) { + s := NewScannerString("0000") + msg, err := ReadMessageString(s) + assert.NoError(t, err) + assert.Equal(t, "", msg) + assertEof(t, s) +} + +func TestReadIncompleteMessage(t *testing.T) { + s := NewScannerString("0005hel") + msg, err := ReadMessageString(s) + assert.Error(t, err) + assert.Equal(t, incompleteMessage("message data", 3, 5), err) + assert.Equal(t, "hel\000\000", msg) + assertEof(t, s) +} + +func NewScannerString(str string) *realScanner { + return NewScanner(NewEofBuffer(str)).(*realScanner) +} + +// NewEofBuffer returns a bytes.Buffer of str that returns an EOF error +// at the end of input, instead of just returning 0 bytes read. +func NewEofBuffer(str string) *bufio.Reader { + return bufio.NewReader(io.LimitReader(bytes.NewBufferString(str), int64(len(str)))) +} + +func assertEof(t *testing.T, s *realScanner) { + msg, err := s.ReadMessage() + assert.Equal(t, io.EOF, err) + assert.Nil(t, msg) +} + +func assertNotEof(t *testing.T, s *realScanner) { + n, err := s.reader.Read(make([]byte, 1)) + assert.Equal(t, 1, n) + assert.NoError(t, err) +} diff --git a/wire/sender.go b/wire/sender.go new file mode 100644 index 0000000..8ee3c4e --- /dev/null +++ b/wire/sender.go @@ -0,0 +1,45 @@ +package wire + +import ( + "fmt" + "io" +) + +// Sender sends messages to the server. +type Sender interface { + SendMessage(msg []byte) error +} + +type realSender struct { + writer io.Writer +} + +func NewSender(w io.Writer) Sender { + return &realSender{w} +} + +func SendMessageString(s Sender, msg string) error { + return s.SendMessage([]byte(msg)) +} + +func (s *realSender) SendMessage(msg []byte) error { + if len(msg) > MaxMessageLength { + return fmt.Errorf("message length exceeds maximum: %d", len(msg)) + } + + lengthAndMsg := fmt.Sprintf("%04x%s", len(msg), msg) + return writeFully(s.writer, []byte(lengthAndMsg)) +} + +func writeFully(w io.Writer, data []byte) error { + for len(data) > 0 { + n, err := w.Write(data) + if err != nil { + return err + } + data = data[n:] + } + return nil +} + +var _ Sender = &realSender{} diff --git a/wire/sender_test.go b/wire/sender_test.go new file mode 100644 index 0000000..3b42b11 --- /dev/null +++ b/wire/sender_test.go @@ -0,0 +1,27 @@ +package wire + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWriteMessage(t *testing.T) { + s, b := NewTestSender() + err := SendMessageString(s, "hello") + assert.NoError(t, err) + assert.Equal(t, "0005hello", b.String()) +} + +func TestWriteEmptyMessage(t *testing.T) { + s, b := NewTestSender() + err := SendMessageString(s, "") + assert.NoError(t, err) + assert.Equal(t, "0000", b.String()) +} + +func NewTestSender() (Sender, *bytes.Buffer) { + var buf bytes.Buffer + return NewSender(&buf), &buf +}