Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 66 additions & 79 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
return nil, nil, nil, nil, err
}

enrollURL, err := url.JoinPath(c.dnServer, message.EnrollEndpoint)
enrollURL, err := urlPath(c.dnServer, message.EnrollEndpoint)
if err != nil {
return nil, nil, nil, nil, err
}
Expand All @@ -172,7 +172,7 @@ func (c *Client) Enroll(ctx context.Context, logger logrus.FieldLogger, code str
}

// Decode the response
r := message.EnrollResponse{}
r := message.APIResponse[message.EnrollResponseData]{}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, nil, nil, nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
Expand Down Expand Up @@ -480,7 +480,7 @@ func (c *Client) streamingPostDNClient(ctx context.Context, reqType string, valu
}
pbb := bytes.NewBuffer(postBody)

endpointV1URL, err := url.JoinPath(c.dnServer, message.EndpointV1)
endpointV1URL, err := urlPath(c.dnServer, message.EndpointV1)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -535,7 +535,7 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
return nil, err
}

endpointV1URL, err := url.JoinPath(c.dnServer, message.EndpointV1)
endpointV1URL, err := urlPath(c.dnServer, message.EndpointV1)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -570,6 +570,57 @@ func (c *Client) postDNClient(ctx context.Context, reqType string, value []byte,
}
}

func callAPI[T any](ctx context.Context, c *Client, method string, endpoint string, payload map[string]any) (*T, error) {
dest, err := urlPath(c.dnServer, endpoint)
if err != nil {
return nil, err
}

var br io.Reader
if payload != nil {
b, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("failed to marshal payload: %s", err)
}
br = bytes.NewReader(b)
}

req, err := http.NewRequestWithContext(ctx, method, dest, br)
if err != nil {
return nil, err
}

resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

reqID := resp.Header.Get("X-Request-ID")

r := message.APIResponse[T]{}
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &APIError{e: fmt.Errorf("error reading response body: %s", err), ReqID: reqID}
}

if err := json.Unmarshal(b, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, b), ReqID: reqID}
}

// Check for any errors returned by the API
if err := r.Errors.ToError(); err != nil {
return nil, &APIError{e: err, ReqID: reqID}
}

// If we didn't detect an error in the response, but received a 4XX or 5XX status code, return error
if resp.StatusCode >= 400 {
return nil, &APIError{e: fmt.Errorf("received HTTP %d from API without error details\nbody: %s", resp.StatusCode, b), ReqID: reqID}
}

return &r.Data, nil
}

// StreamController is used for interacting with streaming requests to the API.
//
// When a streaming request is started in a background goroutine, a StreamController is returned to the caller to allow
Expand Down Expand Up @@ -643,89 +694,25 @@ func nonce() []byte {
}

func (c *Client) EndpointPreAuth(ctx context.Context) (*message.PreAuthData, error) {
dest, err := url.JoinPath(c.dnServer, message.PreAuthEndpoint)
if err != nil {
return nil, err
}

req, err := http.NewRequestWithContext(ctx, "POST", dest, nil)
if err != nil {
return nil, err
}

resp, err := c.client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()

reqID := resp.Header.Get("X-Request-ID")
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
}

switch resp.StatusCode {
case http.StatusOK:
r := message.PreAuthResponse{}
if err = json.Unmarshal(respBody, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
}

if r.Data.PollToken == "" || r.Data.LoginURL == "" {
return nil, &APIError{e: fmt.Errorf("missing pollToken or loginURL"), ReqID: reqID}
}

return &r.Data, nil
default:
var errors struct {
Errors message.APIErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
}
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
}
return callAPI[message.PreAuthData](ctx, c, "POST", message.PreAuthEndpoint, nil)
}

func (c *Client) EndpointAuthPoll(ctx context.Context, pollCode string) (*message.EndpointAuthPollData, error) {
pollURL, err := url.JoinPath(c.dnServer, message.EndpointAuthPoll)
if err != nil {
return nil, err
}
pollURL = fmt.Sprintf("%s?pollToken=%s", pollURL, url.QueryEscape(pollCode))

req, err := http.NewRequestWithContext(ctx, "GET", pollURL, nil)
if err != nil {
return nil, err
}
pollURL := fmt.Sprintf("%s?pollToken=%s", message.EndpointAuthPoll, url.QueryEscape(pollCode))
return callAPI[message.EndpointAuthPollData](ctx, c, "GET", pollURL, nil)
}

resp, err := c.client.Do(req)
func urlPath(base, path string) (string, error) {
baseURL, err := url.Parse(base)
if err != nil {
return nil, err
return "", fmt.Errorf("invalid base: %s", err)
}
defer resp.Body.Close()

reqID := resp.Header.Get("X-Request-ID")
respBody, err := io.ReadAll(resp.Body)
pathURL, err := url.Parse(path)
if err != nil {
return nil, &APIError{e: fmt.Errorf("failed to read the response body: %s", err), ReqID: reqID}
return "", fmt.Errorf("invalid path: %s", err)
}

switch resp.StatusCode {
case http.StatusOK:
r := message.EndpointAuthPollResponse{}
if err = json.Unmarshal(respBody, &r); err != nil {
return nil, &APIError{e: fmt.Errorf("error decoding JSON response: %s\nbody: %s", err, respBody), ReqID: reqID}
}
return &r.Data, nil
default:
var errors struct {
Errors message.APIErrors
}
if err := json.Unmarshal(respBody, &errors); err != nil {
return nil, fmt.Errorf("bad status code '%d', body: %s", resp.StatusCode, respBody)
}
return nil, &APIError{e: errors.Errors.ToError(), ReqID: reqID}
}
finalURL := baseURL.ResolveReference(pathURL)
return finalURL.String(), nil
}
Loading
Loading