diff --git a/cmd/demo/demo.go b/cmd/demo/demo.go index 5a35cd2..ec4064f 100644 --- a/cmd/demo/demo.go +++ b/cmd/demo/demo.go @@ -10,12 +10,16 @@ import ( adb "github.com/zach-klippenstein/goadb" ) -var port = flag.Int("p", goadb.AdbPort, "") +var port = flag.Int("p", adb.AdbPort, "") func main() { flag.Parse() - client := adb.NewHostClientPort(*port) + client, err := adb.NewHostClientPort(*port) + if err != nil { + log.Fatal(err) + } + fmt.Println("Starting server…") client.StartServer() @@ -100,6 +104,14 @@ func PrintDeviceInfo(device *adb.DeviceClient) error { } } + fmt.Println("\tnon-existent file:") + stat, err = device.Stat("/supercalifragilisticexpialidocious") + if err != nil { + fmt.Println("\terror:", err) + } else { + fmt.Printf("\tstat: %+v\n", stat) + } + fmt.Print("\tload avg: ") loadavgReader, err := device.OpenRead("/proc/loadavg") if err != nil { diff --git a/device_client.go b/device_client.go index dcffa42..8a3d503 100644 --- a/device_client.go +++ b/device_client.go @@ -10,7 +10,7 @@ import ( // DeviceClient communicates with a specific Android device. type DeviceClient struct { - dialer nilSafeDialer + dialer wire.Dialer descriptor *DeviceDescriptor } diff --git a/device_client_test.go b/device_client_test.go index 37131d3..762d6d7 100644 --- a/device_client_test.go +++ b/device_client_test.go @@ -12,7 +12,7 @@ func TestGetAttribute(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"value"}, } - client := &DeviceClient{nilSafeDialer{s}, deviceWithSerial("serial")} + client := &DeviceClient{s, deviceWithSerial("serial")} v, err := client.getAttribute("attr") assert.Equal(t, "host-serial:serial:attr", s.Requests[0]) @@ -25,7 +25,7 @@ func TestRunCommandNoArgs(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"output"}, } - client := &DeviceClient{nilSafeDialer{s}, anyDevice()} + client := &DeviceClient{s, anyDevice()} v, err := client.RunCommand("cmd") assert.Equal(t, "host:transport-any", s.Requests[0]) diff --git a/host_client.go b/host_client.go index 8497845..18d80ba 100644 --- a/host_client.go +++ b/host_client.go @@ -2,6 +2,7 @@ package goadb import ( + "errors" "os/exec" "strconv" @@ -26,19 +27,22 @@ See list of services at https://android.googlesource.com/platform/system/core/+/ */ // TODO(z): Finish implementing host services. type HostClient struct { - dialer nilSafeDialer + dialer wire.Dialer } -func NewHostClient() *HostClient { - return NewHostClientDialer(nil) +func NewHostClient() (*HostClient, error) { + return NewHostClientPort(AdbPort) } -func NewHostClientPort(port int) *HostClient { - return NewHostClientDialer(wire.NewDialer("", port)) +func NewHostClientPort(port int) (*HostClient, error) { + return NewHostClientDialer(wire.NewDialer("localhost", port)) } -func NewHostClientDialer(d wire.Dialer) *HostClient { - return &HostClient{nilSafeDialer{d}} +func NewHostClientDialer(d wire.Dialer) (*HostClient, error) { + if d == nil { + return nil, errors.New("dialer cannot be nil.") + } + return &HostClient{d}, nil } // GetServerVersion asks the ADB server for its internal version number. diff --git a/host_client_test.go b/host_client_test.go index 7638f85..c50e63d 100644 --- a/host_client_test.go +++ b/host_client_test.go @@ -14,7 +14,8 @@ func TestGetServerVersion(t *testing.T) { Status: wire.StatusSuccess, Messages: []string{"000a"}, } - client := NewHostClientDialer(s) + client, err := NewHostClientDialer(s) + assert.NoError(t, err) v, err := client.GetServerVersion() assert.Equal(t, "host:version", s.Requests[0]) diff --git a/wire/dialer.go b/wire/dialer.go index 57b0e23..6473ead 100644 --- a/wire/dialer.go +++ b/wire/dialer.go @@ -1,6 +1,7 @@ package wire import ( + "errors" "fmt" "net" "runtime" @@ -27,10 +28,15 @@ func NewDialer(host string, port int) Dialer { func (d *netDialer) Dial() (*Conn, error) { host := d.Host if host == "" { - host = "localhost" + return nil, errors.New("Must specify adb hostname (cannot be empty).") } - netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, d.Port)) + port := d.Port + if port == 0 { + return nil, errors.New("Must specify port (cannot be 0).") + } + + netConn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", host, port)) if err != nil { return nil, err }