Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 55 additions & 84 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto/ed25519"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -134,119 +135,86 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
}

// Make a request to the API with the enrollment code
jv, err := json.Marshal(message.EnrollRequest{
payload := message.EnrollRequest{
Code: code,
NebulaPubkeyX25519: newKeys.NebulaX25519PublicKeyPEM,
HostPubkeyEd25519: hostEd25519PublicKeyPEM,
NebulaPubkeyP256: newKeys.NebulaP256PublicKeyPEM,
HostPubkeyP256: hostP256PublicKeyPEM,
Timestamp: time.Now(),
})
if err != nil {
return nil, nil, nil, nil, err
}

enrollURL, err := urlPath(c.dnServer, message.EnrollEndpoint)
if err != nil {
return nil, nil, nil, nil, err
}

req, err := http.NewRequestWithContext(ctx, "POST", enrollURL, bytes.NewBuffer(jv))
if err != nil {
return nil, nil, nil, nil, err
}

resp, err := c.client.Do(req)
if err != nil {
return nil, nil, nil, nil, err
}
defer resp.Body.Close()

// Log the request ID returned from the server
reqID := resp.Header.Get("X-Request-ID")
l := logger.WithFields(logrus.Fields{"statusCode": resp.StatusCode, "reqID": reqID})
if resp.StatusCode == http.StatusOK {
l.Info("Enrollment request returned success code")
} else {
l.Error("Enrollment request returned error code")
}

// Decode the response
r := message.APIResponse[message.EnrollResponseData]{}
b, err := io.ReadAll(resp.Body)
reqID, r, err := callAPI[message.EnrollResponseData](ctx, c, "POST", message.EnrollEndpoint, payload)
l := logger.WithFields(logrus.Fields{"reqID": reqID})
if err != nil {
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
}

if err := json.Unmarshal(b, &r); err != nil {
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
}

if len(r.Errors) == 1 {
// Check for *only* an "invalid code" error returned by the API
if err := r.Errors[0]; err.Path == "code" && err.Code == "ERR_INVALID_VALUE" {
return nil, nil, nil, nil, &APIError{e: ErrInvalidCode, ReqID: reqID}
}
var apiErrors message.APIErrors
if errors.As(err, &apiErrors) && len(apiErrors) == 1 {
// Check for *only* an "invalid code" error returned by the API
if err := apiErrors[0]; err.Path == "code" && err.Code == "ERR_INVALID_VALUE" {
l.Warn("Enrollment request failed for invalid code")
return nil, nil, nil, nil, &APIError{e: ErrInvalidCode, ReqID: reqID}
}

// Check for *only* a blocked host error returned by the API
if err := r.Errors[0]; err.Path == "" && err.Code == "ERR_HOST_BLOCKED" {
return nil, nil, nil, nil, &APIError{e: ErrHostBlocked, ReqID: reqID}
// Check for *only* a blocked host error returned by the API
if err := apiErrors[0]; err.Path == "" && err.Code == "ERR_HOST_BLOCKED" {
l.Warn("Enrollment request failed for blocked host")
return nil, nil, nil, nil, &APIError{e: ErrHostBlocked, ReqID: reqID}
}
}
}

// Check for any errors returned by the API
if err := r.Errors.ToError(); err != nil {
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unexpected error during enrollment: %v", err), ReqID: reqID}
l.WithError(err).Error("Enrollment request failed with unexpected error")
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unexpected error during enrollment: %w", err), ReqID: reqID}
}
l.Info("Enrollment request succeeded")

meta := &ConfigMeta{
Org: ConfigOrg{
ID: r.Data.Organization.ID,
Name: r.Data.Organization.Name,
ID: r.Organization.ID,
Name: r.Organization.Name,
},
Network: ConfigNetwork{
ID: r.Data.Network.ID,
Name: r.Data.Network.Name,
ID: r.Network.ID,
Name: r.Network.Name,
},
Host: ConfigHost{
ID: r.Data.HostID,
Name: r.Data.Host.Name,
IPAddress: r.Data.Host.IPAddress,
ID: r.HostID,
Name: r.Host.Name,
IPAddress: r.Host.IPAddress,
},
}

if r.Data.EndpointOIDCMeta != nil {
if r.EndpointOIDCMeta != nil {
meta.EndpointOIDC = &ConfigEndpointOIDC{
Email: r.Data.EndpointOIDCMeta.Email,
Email: r.EndpointOIDCMeta.Email,
}
}

// Determine the private keys to save based on the network curve type
var privkeyPEM []byte
var privkey keys.PrivateKey
switch r.Data.Network.Curve {
switch r.Network.Curve {
case message.NetworkCurve25519:
privkeyPEM = newKeys.NebulaX25519PrivateKeyPEM
privkey = newKeys.HostEd25519PrivateKey
case message.NetworkCurveP256:
privkeyPEM = newKeys.NebulaP256PrivateKeyPEM
privkey = newKeys.HostP256PrivateKey
default:
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unsupported curve type: %s", r.Data.Network.Curve), ReqID: reqID}
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("unsupported curve type: %s", r.Network.Curve), ReqID: reqID}
}

trustedKeys, err := keys.TrustedKeysFromPEM(r.Data.TrustedKeys)
trustedKeys, err := keys.TrustedKeysFromPEM(r.TrustedKeys)
if err != nil {
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("failed to load trusted keys from bundle: %s", err), ReqID: reqID}
}

creds := &keys.Credentials{
HostID: r.Data.HostID,
HostID: r.HostID,
PrivateKey: privkey,
Counter: r.Data.Counter,
Counter: r.Counter,
TrustedKeys: trustedKeys,
}
return r.Data.Config, privkeyPEM, creds, meta, nil
return r.Config, privkeyPEM, creds, meta, nil
}

// CheckForUpdate sends a signed message to the DNClient API to learn if there is a new configuration available.
Expand Down Expand Up @@ -514,12 +482,12 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
sc.err.Store(ErrInvalidCredentials)
default:
var errors struct {
Errors message.APIErrors
Errors message.APIResponseErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
sc.err.Store(fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody))
} else {
sc.err.Store(errors.Errors.ToError())
sc.err.Store(errors.Errors.Err())
}
}
}()
Expand Down Expand Up @@ -561,38 +529,39 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
return nil, ErrInvalidCredentials
default:
var errors struct {
Errors message.APIErrors
Errors message.APIResponseErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("dnclient endpoint returned bad status code '%d', body: %s", resp.StatusCode, respBody)
}
return nil, errors.Errors.ToError()
return nil, errors.Errors.Err()
}
}

func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload map[string]any) (*T, error) {
// callAPI returns the request ID, requested response data, and any error if applicable.
func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload any) (string, *T, error) {
dest, err := urlPath(c.dnServer, endpoint)
if err != nil {
return nil, err
return "", nil, err
}

var br io.Reader
if payload != nil {
b, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal payload: %s", err)
return "", nil, fmt.Errorf("failed to marshal payload: %s", err)
}
br = bytes.NewReader(b)
}

req, err := http.NewRequestWithContext(ctx, method, dest, br)
if err != nil {
return nil, err
return "", nil, err
}

resp, err := c.client.Do(req)
if err != nil {
return nil, err
return "", nil, err
}
defer resp.Body.Close()

Expand All @@ -601,24 +570,24 @@ func callAPI[T any](ctx context.Context, c *Client, method string, endpoint stri
r := message.APIResponse[T]{}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
return reqID, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
}

if err := json.Unmarshal(b, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
return reqID, nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
}

// Check for any errors returned by the API
if err := r.Errors.ToError(); err != nil {
return nil, &APIError{e: err, ReqID: reqID}
if err := r.Errors.Err(); err != nil {
return reqID, nil, &APIError{e: err, ReqID: reqID}
}

// If we didn't detect an error in the response, but received a 4XX or 5XX status code, return error
if resp.StatusCode >= 400 {
return nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID}
return reqID, nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID}
}

return &r.Data, nil
return reqID, &r.Data, nil
}

// StreamController is used for interacting with streaming requests to the API.
Expand Down Expand Up @@ -694,12 +663,14 @@ func nonce() []byte {
}

func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, error) {
return callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil)
_, d, err := callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil)
return d, err
}

func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
pollURL := fmt.Sprintf("%s?pollToken=%s", message.EndpointAuthPoll, url.QueryEscape(pollCode))
return callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil)
pollURL := fmt.Sprintf("%s?pollToken=%s", message.AuthPollEndpoint, url.QueryEscape(pollCode))
_, d, err := callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil)
return d, err
}

func urlPath(base, path string) (string, error) {
Expand Down
26 changes: 13 additions & 13 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func TestEnroll(t *testing.T) {
})
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestEnroll(t *testing.T) {
errorMsg := "invalid enrollment code"
ts.ExpectEnrollment(code, message.NetworkCurve25519, func(req message.EnrollRequest) []byte {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_INVALID_ENROLLMENT_CODE",
Message: errorMsg,
}},
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestDoUpdate(t *testing.T) {
})
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
Expand Down Expand Up @@ -463,7 +463,7 @@ func TestDoUpdate_P256(t *testing.T) {
})
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
Expand Down Expand Up @@ -557,7 +557,7 @@ func TestDoUpdate_P256(t *testing.T) {
sig, err := nk.HostP256PrivateKey.Sign(rawRes)
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
Message: "failed to sign message",
}},
Expand Down Expand Up @@ -601,7 +601,7 @@ func TestDoUpdate_P256(t *testing.T) {
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
Message: "failed to sign message",
}},
Expand Down Expand Up @@ -655,7 +655,7 @@ func TestDoUpdate_P256(t *testing.T) {
sig, err := ecdsa.SignASN1(rand.Reader, caPrivkey, hashed[:])
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_SIGN_MESSAGE",
Message: "failed to sign message",
}},
Expand Down Expand Up @@ -703,7 +703,7 @@ func TestCommandResponse(t *testing.T) {
})
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
Expand Down Expand Up @@ -774,7 +774,7 @@ func TestCommandResponse(t *testing.T) {
errorMsg := "sample error"
ts.ExpectDNClientRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_INVALID_VALUE",
Message: errorMsg,
}},
Expand Down Expand Up @@ -808,7 +808,7 @@ func TestStreamCommandResponse(t *testing.T) {
})
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
Expand Down Expand Up @@ -885,7 +885,7 @@ func TestStreamCommandResponse(t *testing.T) {
errorMsg := "sample error"
ts.ExpectStreamingRequest(message.CommandResponse, http.StatusBadRequest, func(r message.RequestWrapper) []byte {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_INVALID_VALUE",
Message: errorMsg,
}},
Expand Down Expand Up @@ -934,7 +934,7 @@ func TestReauthenticate(t *testing.T) {
})
if err != nil {
return jsonMarshal(message.APIResponse[message.EnrollResponseData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_FAILED_TO_MARSHAL_YAML",
Message: "failed to marshal test response config",
}},
Expand Down Expand Up @@ -1094,7 +1094,7 @@ func TestGetOidcPollCode(t *testing.T) {
//unhappy path
ts.ExpectAPIRequest(http.StatusInternalServerError, func(req any) []byte {
return jsonMarshal(message.APIResponse[message.PreAuthData]{
Errors: message.APIErrors{{
Errors: message.APIResponseErrors{{
Code: "ERR_INTERNAL_SERVER_ERROR",
Message: "internal server error",
}},
Expand Down
2 changes: 1 addition & 1 deletion dnapitest/dnapitest.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
s.expectedRequests = s.expectedRequests[1:]
w.WriteHeader(expected.StatusCode())
_, _ = w.Write(expected.Respond(nil))
case message.EndpointAuthPoll:
case message.AuthPollEndpoint:
s.handlerDoOidcPoll(w, r)
default:
s.errors = append(s.errors, fmt.Errorf("invalid request path %s", r.URL.Path))
Expand Down
Loading
Loading