diff --git a/client.go b/client.go index b4cc441..58ca3d7 100644 --- a/client.go +++ b/client.go @@ -8,6 +8,7 @@ import ( "crypto/ed25519" "crypto/rand" "encoding/json" + "errors" "fmt" "io" "net" @@ -134,97 +135,64 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str } // Make a request to the API with the enrollment code - jv, err := json.Marshal(message.EnrollRequest{ + payload := message.EnrollRequest{ Code: code, NebulaPubkeyX25519: newKeys.NebulaX25519PublicKeyPEM, HostPubkeyEd25519: hostEd25519PublicKeyPEM, NebulaPubkeyP256: newKeys.NebulaP256PublicKeyPEM, HostPubkeyP256: hostP256PublicKeyPEM, Timestamp: time.Now(), - }) - if err != nil { - return nil, nil, nil, nil, err - } - - enrollURL, err := urlPath(c.dnServer, message.EnrollEndpoint) - if err != nil { - return nil, nil, nil, nil, err - } - - req, err := http.NewRequestWithContext(ctx, "POST", enrollURL, bytes.NewBuffer(jv)) - if err != nil { - return nil, nil, nil, nil, err - } - - resp, err := c.client.Do(req) - if err != nil { - return nil, nil, nil, nil, err - } - defer resp.Body.Close() - - // Log the request ID returned from the server - reqID := resp.Header.Get("X-Request-ID") - l := logger.WithFields(logrus.Fields{"statusCode": resp.StatusCode, "reqID": reqID}) - if resp.StatusCode == http.StatusOK { - l.Info("Enrollment request returned success code") - } else { - l.Error("Enrollment request returned error code") } - // Decode the response - r := message.APIResponse[message.EnrollResponseData]{} - b, err := io.ReadAll(resp.Body) + reqID, r, err := callAPI[message.EnrollResponseData](ctx, c, "POST", message.EnrollEndpoint, payload) + l := logger.WithFields(logrus.Fields{"reqID": reqID}) if err != nil { - return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID} - } - - if err := json.Unmarshal(b, &r); err != nil { - return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID} - } - - if len(r.Errors) == 1 { - // Check for *only* an "invalid code" error returned by the API - if err := r.Errors[0]; err.Path == "code" && err.Code == "ERR_INVALID_VALUE" { - return nil, nil, nil, nil, &APIError{e: ErrInvalidCode, ReqID: reqID} - } + var apiErrors message.APIErrors + if errors.As(err, &apiErrors) && len(apiErrors) == 1 { + // Check for *only* an "invalid code" error returned by the API + if err := apiErrors[0]; err.Path == "code" && err.Code == "ERR_INVALID_VALUE" { + l.Warn("Enrollment request failed for invalid code") + return nil, nil, nil, nil, &APIError{e: ErrInvalidCode, ReqID: reqID} + } - // Check for *only* a blocked host error returned by the API - if err := r.Errors[0]; err.Path == "" && err.Code == "ERR_HOST_BLOCKED" { - return nil, nil, nil, nil, &APIError{e: ErrHostBlocked, ReqID: reqID} + // Check for *only* a blocked host error returned by the API + if err := apiErrors[0]; err.Path == "" && err.Code == "ERR_HOST_BLOCKED" { + l.Warn("Enrollment request failed for blocked host") + return nil, nil, nil, nil, &APIError{e: ErrHostBlocked, ReqID: reqID} + } } - } - // Check for any errors returned by the API - if err := r.Errors.ToError(); err != nil { - return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unexpected error during enrollment: %v", err), ReqID: reqID} + l.WithError(err).Error("Enrollment request failed with unexpected error") + return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unexpected error during enrollment: %w", err), ReqID: reqID} } + l.Info("Enrollment request succeeded") meta := &ConfigMeta{ Org: ConfigOrg{ - ID: r.Data.Organization.ID, - Name: r.Data.Organization.Name, + ID: r.Organization.ID, + Name: r.Organization.Name, }, Network: ConfigNetwork{ - ID: r.Data.Network.ID, - Name: r.Data.Network.Name, + ID: r.Network.ID, + Name: r.Network.Name, }, Host: ConfigHost{ - ID: r.Data.HostID, - Name: r.Data.Host.Name, - IPAddress: r.Data.Host.IPAddress, + ID: r.HostID, + Name: r.Host.Name, + IPAddress: r.Host.IPAddress, }, } - if r.Data.EndpointOIDCMeta != nil { + if r.EndpointOIDCMeta != nil { meta.EndpointOIDC = &ConfigEndpointOIDC{ - Email: r.Data.EndpointOIDCMeta.Email, + Email: r.EndpointOIDCMeta.Email, } } // Determine the private keys to save based on the network curve type var privkeyPEM []byte var privkey keys.PrivateKey - switch r.Data.Network.Curve { + switch r.Network.Curve { case message.NetworkCurve25519: privkeyPEM = newKeys.NebulaX25519PrivateKeyPEM privkey = newKeys.HostEd25519PrivateKey @@ -232,21 +200,21 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str privkeyPEM = newKeys.NebulaP256PrivateKeyPEM privkey = newKeys.HostP256PrivateKey default: - return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unsupported curve type: %s", r.Data.Network.Curve), ReqID: reqID} + return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unsupported curve type: %s", r.Network.Curve), ReqID: reqID} } - trustedKeys, err := keys.TrustedKeysFromPEM(r.Data.TrustedKeys) + trustedKeys, err := keys.TrustedKeysFromPEM(r.TrustedKeys) if err != nil { return nil, nil, nil, nil, &APIError{e: fmt.Errorf("failed to load trusted keys from bundle: %s", err), ReqID: reqID} } creds := &keys.Credentials{ - HostID: r.Data.HostID, + HostID: r.HostID, PrivateKey: privkey, - Counter: r.Data.Counter, + Counter: r.Counter, TrustedKeys: trustedKeys, } - return r.Data.Config, privkeyPEM, creds, meta, nil + return r.Config, privkeyPEM, creds, meta, nil } // CheckForUpdate sends a signed message to the DNClient API to learn if there is a new configuration available. @@ -514,12 +482,12 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu sc.err.Store(ErrInvalidCredentials) default: var errors struct { - Errors message.APIErrors + Errors message.APIResponseErrors } if err := json.Unmarshal(respBody, &errors); err != nil { sc.err.Store(fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody)) } else { - sc.err.Store(errors.Errors.ToError()) + sc.err.Store(errors.Errors.Err()) } } }() @@ -561,38 +529,39 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte, return nil, ErrInvalidCredentials default: var errors struct { - Errors message.APIErrors + Errors message.APIResponseErrors } if err := json.Unmarshal(respBody, &errors); err != nil { return nil, fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody) } - return nil, errors.Errors.ToError() + return nil, errors.Errors.Err() } } -func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload map[string]any) (*T, error) { +// callAPI returns the request ID, requested response data, and any error if applicable. +func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload any) (string, *T, error) { dest, err := urlPath(c.dnServer, endpoint) if err != nil { - return nil, err + return "", nil, err } var br io.Reader if payload != nil { b, err := json.Marshal(payload) if err != nil { - return nil, fmt.Errorf("failed to marshal payload: %s", err) + return "", nil, fmt.Errorf("failed to marshal payload: %s", err) } br = bytes.NewReader(b) } req, err := http.NewRequestWithContext(ctx, method, dest, br) if err != nil { - return nil, err + return "", nil, err } resp, err := c.client.Do(req) if err != nil { - return nil, err + return "", nil, err } defer resp.Body.Close() @@ -601,24 +570,24 @@ func callAPI[T any](ctx context.Context, c *Client, method string, endpoint stri r := message.APIResponse[T]{} b, err := io.ReadAll(resp.Body) if err != nil { - return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID} + return reqID, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID} } if err := json.Unmarshal(b, &r); err != nil { - return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID} + return reqID, nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID} } // Check for any errors returned by the API - if err := r.Errors.ToError(); err != nil { - return nil, &APIError{e: err, ReqID: reqID} + if err := r.Errors.Err(); err != nil { + return reqID, nil, &APIError{e: err, ReqID: reqID} } // If we didn't detect an error in the response, but received a 4XX or 5XX status code, return error if resp.StatusCode >= 400 { - return nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID} + return reqID, nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID} } - return &r.Data, nil + return reqID, &r.Data, nil } // StreamController is used for interacting with streaming requests to the API. @@ -694,12 +663,14 @@ func nonce() []byte { } func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, error) { - return callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil) + _, d, err := callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil) + return d, err } func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) { - pollURL := fmt.Sprintf("%s?pollToken=%s", message.EndpointAuthPoll, url.QueryEscape(pollCode)) - return callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil) + pollURL := fmt.Sprintf("%s?pollToken=%s", message.AuthPollEndpoint, url.QueryEscape(pollCode)) + _, d, err := callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil) + return d, err } func urlPath(base, path string) (string, error) { diff --git a/client_test.go b/client_test.go index 311ea5a..ab662df 100644 --- a/client_test.go +++ b/client_test.go @@ -65,7 +65,7 @@ func TestEnroll(t *testing.T) { }) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", }}, @@ -149,7 +149,7 @@ func TestEnroll(t *testing.T) { errorMsg := "invalid enrollment code" ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_INVALID_ENROLLMENT_CODE", Message: errorMsg, }}, @@ -194,7 +194,7 @@ func TestDoUpdate(t *testing.T) { }) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", }}, @@ -463,7 +463,7 @@ func TestDoUpdate_P256(t *testing.T) { }) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", }}, @@ -557,7 +557,7 @@ func TestDoUpdate_P256(t *testing.T) { sig, err := nk.HostP256PrivateKey.Sign(rawRes) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_SIGN_MESSAGE", Message: "failed to sign message", }}, @@ -601,7 +601,7 @@ func TestDoUpdate_P256(t *testing.T) { sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:]) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_SIGN_MESSAGE", Message: "failed to sign message", }}, @@ -655,7 +655,7 @@ func TestDoUpdate_P256(t *testing.T) { sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:]) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_SIGN_MESSAGE", Message: "failed to sign message", }}, @@ -703,7 +703,7 @@ func TestCommandResponse(t *testing.T) { }) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", }}, @@ -774,7 +774,7 @@ func TestCommandResponse(t *testing.T) { errorMsg := "sample error" ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_INVALID_VALUE", Message: errorMsg, }}, @@ -808,7 +808,7 @@ func TestStreamCommandResponse(t *testing.T) { }) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", }}, @@ -885,7 +885,7 @@ func TestStreamCommandResponse(t *testing.T) { errorMsg := "sample error" ts.ExpectStreamingRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_INVALID_VALUE", Message: errorMsg, }}, @@ -934,7 +934,7 @@ func TestReauthenticate(t *testing.T) { }) if err != nil { return jsonMarshal(message.APIResponse[message.EnrollResponseData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_FAILED_TO_MARSHAL_YAML", Message: "failed to marshal test response config", }}, @@ -1094,7 +1094,7 @@ func TestGetOidcPollCode(t *testing.T) { //unhappy path ts.ExpectAPIRequest(http.StatusInternalServerError, func(req any) []byte { return jsonMarshal(message.APIResponse[message.PreAuthData]{ - Errors: message.APIErrors{{ + Errors: message.APIResponseErrors{{ Code: "ERR_INTERNAL_SERVER_ERROR", Message: "internal server error", }}, diff --git a/dnapitest/dnapitest.go b/dnapitest/dnapitest.go index 47364f0..249add5 100644 --- a/dnapitest/dnapitest.go +++ b/dnapitest/dnapitest.go @@ -72,7 +72,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { s.expectedRequests = s.expectedRequests[1:] w.WriteHeader(expected.StatusCode()) _, _ = w.Write(expected.Respond(nil)) - case message.EndpointAuthPoll: + case message.AuthPollEndpoint: s.handlerDoOidcPoll(w, r) default: s.errors = append(s.errors, fmt.Errorf("invalid request path %s", r.URL.Path)) diff --git a/message/message.go b/message/message.go index b5d1d4e..dab9c90 100644 --- a/message/message.go +++ b/message/message.go @@ -132,24 +132,36 @@ type ReauthenticateResponse struct { // APIResponse is a standard format for the DN API. It does not apply to the DNClient API. type APIResponse[T any] struct { - Data T `json:"data"` - Errors APIErrors `json:"errors"` + Data T `json:"data"` + Errors APIResponseErrors `json:"errors"` } -// APIError represents a single error returned in an API error response. -type APIError struct { +// APIResponseError represents a single error returned in an API error response. +type APIResponseError struct { Code string `json:"code"` Message string `json:"message"` Path string `json:"path"` // may or may not be present } +// APIResponseErrors is used to parse errors but is not a Golang error itself. +// It may or may not contain actual errors - if it doesn't, it should not be +// converted to an error. This should not be returned from the dnapi package. +type APIResponseErrors []APIResponseError + +func (m APIResponseErrors) Err() error { + if len(m) > 0 { + return APIErrors(m) + } + return nil +} + // APIErrors facilitates converting multiple API errors into a single Golang // error to be returned to callers. -type APIErrors []APIError +type APIErrors APIResponseErrors -func (errs APIErrors) ToError() error { - if len(errs) == 0 { - return nil +func (errs APIErrors) Error() string { + if len(errs) == 0 { // this shouldn't happen + panic("no errors") } s := make([]string, len(errs)) @@ -157,7 +169,7 @@ func (errs APIErrors) ToError() error { s[i] = errs[i].Message } - return errors.New(strings.Join(s, ", ")) + return strings.Join(s, ", ") } // EnrollEndpoint is the REST enrollment endpoint. @@ -219,7 +231,7 @@ type PreAuthData struct { LoginURL string `json:"loginURL"` } -const EndpointAuthPoll = "/v1/endpoint-auth/poll" +const AuthPollEndpoint = "/v1/endpoint-auth/poll" type EndpointAuthState string