diff --git a/client_config.go b/client_config.go new file mode 100644 index 0000000..b6867fe --- /dev/null +++ b/client_config.go @@ -0,0 +1,16 @@ +package goadb + +var ( + defaultDialer Dialer = NewDialer("", 0) +) + +type ClientConfig struct { + Dialer Dialer +} + +func (c ClientConfig) sanitized() ClientConfig { + if c.Dialer == nil { + c.Dialer = defaultDialer + } + return c +} diff --git a/cmd/demo/demo.go b/cmd/demo/demo.go index ec4064f..61ef9f2 100644 --- a/cmd/demo/demo.go +++ b/cmd/demo/demo.go @@ -15,13 +15,10 @@ var port = flag.Int("p", adb.AdbPort, "") func main() { flag.Parse() - client, err := adb.NewHostClientPort(*port) - if err != nil { - log.Fatal(err) - } + client := adb.NewHostClient(adb.ClientConfig{}) fmt.Println("Starting server…") - client.StartServer() + adb.StartServer() serverVersion, err := client.GetServerVersion() if err != nil { @@ -38,23 +35,24 @@ func main() { fmt.Printf("\t%+v\n", *device) } - PrintDeviceInfoAndError(client.GetAnyDevice()) - PrintDeviceInfoAndError(client.GetLocalDevice()) - PrintDeviceInfoAndError(client.GetUsbDevice()) + PrintDeviceInfoAndError(adb.AnyDevice()) + PrintDeviceInfoAndError(adb.AnyLocalDevice()) + PrintDeviceInfoAndError(adb.AnyUsbDevice()) serials, err := client.ListDeviceSerials() if err != nil { log.Fatal(err) } for _, serial := range serials { - PrintDeviceInfoAndError(client.GetDeviceWithSerial(serial)) + PrintDeviceInfoAndError(adb.DeviceWithSerial(serial)) } //fmt.Println("Killing server…") //client.KillServer() } -func PrintDeviceInfoAndError(device *adb.DeviceClient) { +func PrintDeviceInfoAndError(descriptor adb.DeviceDescriptor) { + device := adb.NewDeviceClient(adb.ClientConfig{}, descriptor) if err := PrintDeviceInfo(device); err != nil { log.Println(err) } diff --git a/cmd/raw-adb/raw-adb.go b/cmd/raw-adb/raw-adb.go index c7fc4df..231e830 100644 --- a/cmd/raw-adb/raw-adb.go +++ b/cmd/raw-adb/raw-adb.go @@ -49,7 +49,7 @@ func readLine() string { } func doCommand(cmd string) error { - conn, err := wire.NewDialer("", *port).Dial() + conn, err := goadb.NewDialer("", *port).Dial() if err != nil { log.Fatal(err) } diff --git a/device_client.go b/device_client.go index ea5961c..f02f196 100644 --- a/device_client.go +++ b/device_client.go @@ -10,8 +10,15 @@ import ( // DeviceClient communicates with a specific Android device. type DeviceClient struct { - dialer wire.Dialer - descriptor *DeviceDescriptor + config ClientConfig + descriptor DeviceDescriptor +} + +func NewDeviceClient(config ClientConfig, descriptor DeviceDescriptor) *DeviceClient { + return &DeviceClient{ + config: config.sanitized(), + descriptor: descriptor, + } } func (c *DeviceClient) String() string { @@ -134,7 +141,7 @@ func (c *DeviceClient) OpenRead(path string) (io.ReadCloser, error) { // getAttribute returns the first message returned by the server by running // :, where host-prefix is determined from the DeviceDescriptor. func (c *DeviceClient) getAttribute(attr string) (string, error) { - resp, err := wire.RoundTripSingleResponse(c.dialer, + resp, err := roundTripSingleResponse(c.config.Dialer, fmt.Sprintf("%s:%s", c.descriptor.getHostPrefix(), attr)) if err != nil { return "", err @@ -145,9 +152,9 @@ func (c *DeviceClient) getAttribute(attr string) (string, error) { // dialDevice switches the connection to communicate directly with the device // by requesting the transport defined by the DeviceDescriptor. func (c *DeviceClient) dialDevice() (*wire.Conn, error) { - conn, err := c.dialer.Dial() + conn, err := c.config.Dialer.Dial() if err != nil { - return nil, fmt.Errorf("error dialing adb server (%s): %+v", c.dialer, err) + return nil, fmt.Errorf("error dialing adb server (%s): %+v", c.config.Dialer, err) } req := fmt.Sprintf("host:%s", c.descriptor.getTransportDescriptor()) @@ -197,7 +204,7 @@ func prepareCommandLine(cmd string, args ...string) (string, error) { } } - // Prepend the comand to the args array. + // Prepend the command to the args array. if len(args) > 0 { cmd = fmt.Sprintf("%s %s", cmd, strings.Join(args, " ")) } diff --git a/device_client_test.go b/device_client_test.go index 762d6d7..18b8a52 100644 --- a/device_client_test.go +++ b/device_client_test.go @@ -12,7 +12,12 @@ func TestGetAttribute(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"value"}, } - client := &DeviceClient{s, deviceWithSerial("serial")} + client := NewDeviceClient( + ClientConfig{ + Dialer: s, + }, + DeviceWithSerial("serial"), + ) v, err := client.getAttribute("attr") assert.Equal(t, "host-serial:serial:attr", s.Requests[0]) @@ -25,7 +30,12 @@ func TestRunCommandNoArgs(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"output"}, } - client := &DeviceClient{s, anyDevice()} + client := NewDeviceClient( + ClientConfig{ + Dialer: s, + }, + AnyDevice(), + ) v, err := client.RunCommand("cmd") assert.Equal(t, "host:transport-any", s.Requests[0]) diff --git a/device_descriptor.go b/device_descriptor.go index 739f7d8..db78b14 100644 --- a/device_descriptor.go +++ b/device_descriptor.go @@ -23,33 +23,33 @@ type DeviceDescriptor struct { serial string } -func anyDevice() *DeviceDescriptor { - return &DeviceDescriptor{descriptorType: DeviceAny} +func AnyDevice() DeviceDescriptor { + return DeviceDescriptor{descriptorType: DeviceAny} } -func anyUsbDevice() *DeviceDescriptor { - return &DeviceDescriptor{descriptorType: DeviceUsb} +func AnyUsbDevice() DeviceDescriptor { + return DeviceDescriptor{descriptorType: DeviceUsb} } -func anyLocalDevice() *DeviceDescriptor { - return &DeviceDescriptor{descriptorType: DeviceLocal} +func AnyLocalDevice() DeviceDescriptor { + return DeviceDescriptor{descriptorType: DeviceLocal} } -func deviceWithSerial(serial string) *DeviceDescriptor { - return &DeviceDescriptor{ +func DeviceWithSerial(serial string) DeviceDescriptor { + return DeviceDescriptor{ descriptorType: DeviceSerial, serial: serial, } } -func (d *DeviceDescriptor) String() string { +func (d DeviceDescriptor) String() string { if d.descriptorType == DeviceSerial { return fmt.Sprintf("%s[%s]", d.descriptorType, d.serial) } return d.descriptorType.String() } -func (d *DeviceDescriptor) getHostPrefix() string { +func (d DeviceDescriptor) getHostPrefix() string { switch d.descriptorType { case DeviceAny: return "host" @@ -64,7 +64,7 @@ func (d *DeviceDescriptor) getHostPrefix() string { } } -func (d *DeviceDescriptor) getTransportDescriptor() string { +func (d DeviceDescriptor) getTransportDescriptor() string { switch d.descriptorType { case DeviceAny: return "transport-any" diff --git a/wire/dialer.go b/dialer.go similarity index 65% rename from wire/dialer.go rename to dialer.go index cb21431..c60a2c0 100644 --- a/wire/dialer.go +++ b/dialer.go @@ -1,17 +1,38 @@ -package wire +package goadb import ( - "errors" "fmt" "net" "runtime" + + "github.com/zach-klippenstein/goadb/wire" +) + +const ( + // Default port the adb server listens on. + AdbPort = 5037 ) /* Dialer knows how to create connections to an adb server. */ type Dialer interface { - Dial() (*Conn, error) + Dial() (*wire.Conn, error) +} + +/* +NewDialer creates a new Dialer. + +If host is "" or port is 0, "localhost:5037" is used. +*/ +func NewDialer(host string, port int) Dialer { + if host == "" { + host = "localhost" + } + if port == 0 { + port = AdbPort + } + return &netDialer{host, port} } type netDialer struct { @@ -19,35 +40,24 @@ type netDialer struct { Port int } -func NewDialer(host string, port int) Dialer { - return &netDialer{host, port} -} - func (d *netDialer) String() string { return fmt.Sprintf("netDialer(%s:%d)", d.Host, d.Port) } // Dial connects to the adb server on the host and port set on the netDialer. // The zero-value will connect to the default, localhost:5037. -func (d *netDialer) Dial() (*Conn, error) { +func (d *netDialer) Dial() (*wire.Conn, error) { host := d.Host - if host == "" { - return nil, errors.New("Must specify adb hostname (cannot be empty).") - } - port := d.Port - if port == 0 { - return nil, errors.New("Must specify port (cannot be 0).") - } netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) if err != nil { return nil, err } - conn := &Conn{ - Scanner: NewScanner(netConn), - Sender: NewSender(netConn), + conn := &wire.Conn{ + Scanner: wire.NewScanner(netConn), + Sender: wire.NewSender(netConn), } // Prevent leaking the network connection, not sure if TCPConn does this itself. @@ -58,7 +68,8 @@ func (d *netDialer) Dial() (*Conn, error) { return conn, nil } -func RoundTripSingleResponse(d Dialer, req string) ([]byte, error) { +// TODO(zach): Make this unexported. +func roundTripSingleResponse(d Dialer, req string) ([]byte, error) { conn, err := d.Dial() if err != nil { return nil, err diff --git a/host_client.go b/host_client.go index 18d80ba..a9a9cad 100644 --- a/host_client.go +++ b/host_client.go @@ -2,18 +2,11 @@ package goadb import ( - "errors" - "os/exec" "strconv" "github.com/zach-klippenstein/goadb/wire" ) -const ( - // Default port the adb server listens on. - AdbPort = 5037 -) - /* HostClient communicates with host services on the adb server. @@ -27,27 +20,31 @@ See list of services at https://android.googlesource.com/platform/system/core/+/ */ // TODO(z): Finish implementing host services. type HostClient struct { - dialer wire.Dialer + config ClientConfig } -func NewHostClient() (*HostClient, error) { - return NewHostClientPort(AdbPort) -} +// func NewHostClient() (*HostClient, error) { +// return NewHostClientPort(AdbPort) +// } -func NewHostClientPort(port int) (*HostClient, error) { - return NewHostClientDialer(wire.NewDialer("localhost", port)) -} +// func NewHostClientPort(port int) (*HostClient, error) { +// return NewHostClientDialer(wire.NewDialer("localhost", port)) +// } -func NewHostClientDialer(d wire.Dialer) (*HostClient, error) { - if d == nil { - return nil, errors.New("dialer cannot be nil.") - } - return &HostClient{d}, nil +// func NewHostClientDialer(d wire.Dialer) (*HostClient, error) { +// if d == nil { +// return nil, errors.New("dialer cannot be nil.") +// } +// return &HostClient{d}, nil +// } + +func NewHostClient(config ClientConfig) *HostClient { + return &HostClient{config.sanitized()} } // GetServerVersion asks the ADB server for its internal version number. func (c *HostClient) GetServerVersion() (int, error) { - resp, err := wire.RoundTripSingleResponse(c.dialer, "host:version") + resp, err := roundTripSingleResponse(c.config.Dialer, "host:version") if err != nil { return 0, err } @@ -63,7 +60,7 @@ Corresponds to the command: adb kill-server */ func (c *HostClient) KillServer() error { - conn, err := c.dialer.Dial() + conn, err := c.config.Dialer.Dial() if err != nil { return err } @@ -76,17 +73,6 @@ func (c *HostClient) KillServer() error { return nil } -/* -StartServer ensures there is a server running. - -Currently implemented by just running - adb start-server -*/ -func (c *HostClient) StartServer() error { - cmd := exec.Command("adb", "start-server") - return cmd.Run() -} - /* ListDeviceSerials returns the serial numbers of all attached devices. @@ -94,7 +80,7 @@ Corresponds to the command: adb devices */ func (c *HostClient) ListDeviceSerials() ([]string, error) { - resp, err := wire.RoundTripSingleResponse(c.dialer, "host:devices") + resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices") if err != nil { return nil, err } @@ -118,41 +104,10 @@ Corresponds to the command: adb devices -l */ func (c *HostClient) ListDevices() ([]*DeviceInfo, error) { - resp, err := wire.RoundTripSingleResponse(c.dialer, "host:devices-l") + resp, err := roundTripSingleResponse(c.config.Dialer, "host:devices-l") if err != nil { return nil, err } return parseDeviceList(string(resp), parseDeviceLong) } - -func (c *HostClient) GetDevice(d *DeviceInfo) *DeviceClient { - return c.GetDeviceWithSerial(d.Serial) -} - -// GetDeviceWithSerial returns a client for the device with the specified serial number. -// Will return a client even if there is no matching device connected. -func (c *HostClient) GetDeviceWithSerial(serial string) *DeviceClient { - return c.getDevice(deviceWithSerial(serial)) -} - -// GetAnyDevice returns a client for any one connected device. -func (c *HostClient) GetAnyDevice() *DeviceClient { - return c.getDevice(anyDevice()) -} - -// GetUsbDevice returns a client for the USB device. -// Will return a client even if there is no device connected. -func (c *HostClient) GetUsbDevice() *DeviceClient { - return c.getDevice(anyUsbDevice()) -} - -// GetLocalDevice returns a client for the local device. -// Will return a client even if there is no device connected. -func (c *HostClient) GetLocalDevice() *DeviceClient { - return c.getDevice(anyLocalDevice()) -} - -func (c *HostClient) getDevice(descriptor *DeviceDescriptor) *DeviceClient { - return &DeviceClient{c.dialer, descriptor} -} diff --git a/host_client_test.go b/host_client_test.go index c50e63d..7f26b1d 100644 --- a/host_client_test.go +++ b/host_client_test.go @@ -14,8 +14,9 @@ func TestGetServerVersion(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"000a"}, } - client, err := NewHostClientDialer(s) - assert.NoError(t, err) + client := NewHostClient(ClientConfig{ + Dialer: s, + }) v, err := client.GetServerVersion() assert.Equal(t, "host:version", s.Requests[0]) diff --git a/nil_safe_dialer.go b/nil_safe_dialer.go deleted file mode 100644 index 3d51349..0000000 --- a/nil_safe_dialer.go +++ /dev/null @@ -1,15 +0,0 @@ -package goadb - -import "github.com/zach-klippenstein/goadb/wire" - -type nilSafeDialer struct { - wire.Dialer -} - -func (d nilSafeDialer) Dial() (*wire.Conn, error) { - if d.Dialer == nil { - d.Dialer = wire.NewDialer("", AdbPort) - } - - return d.Dialer.Dial() -} diff --git a/server_controller.go b/server_controller.go new file mode 100644 index 0000000..c9cd01e --- /dev/null +++ b/server_controller.go @@ -0,0 +1,14 @@ +package goadb + +import "os/exec" + +/* +StartServer ensures there is a server running. + +Currently implemented by just running + adb start-server +*/ +func StartServer() error { + cmd := exec.Command("adb", "start-server") + return cmd.Run() +} diff --git a/wire/adb_error.go b/wire/adb_server_error.go similarity index 80% rename from wire/adb_error.go rename to wire/adb_server_error.go index dddaccd..df79b28 100644 --- a/wire/adb_error.go +++ b/wire/adb_server_error.go @@ -4,14 +4,14 @@ import ( "fmt" ) -type AdbError struct { +type AdbServerError struct { Request string ServerMsg string } -var _ error = &AdbError{} +var _ error = &AdbServerError{} -func (e *AdbError) Error() string { +func (e *AdbServerError) Error() string { if e.Request == "" { return fmt.Sprintf("server error: %s", e.ServerMsg) } else { diff --git a/wire/util.go b/wire/util.go index f0c00c9..cddb3f5 100644 --- a/wire/util.go +++ b/wire/util.go @@ -20,7 +20,7 @@ func ReadStatusFailureAsError(s Scanner, req string) error { return fmt.Errorf("server returned error for %s, but couldn't read the error message: %+v", err) } - return &AdbError{ + return &AdbServerError{ Request: req, ServerMsg: string(msg), }