Skip to content

Commit

Permalink
chore: add files UI to tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
ibuildthecloud committed Dec 17, 2024
1 parent 8e863f3 commit 8fa4f94
Show file tree
Hide file tree
Showing 22 changed files with 414 additions and 79 deletions.
19 changes: 15 additions & 4 deletions apiclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,22 @@ func (c *Client) putJSON(ctx context.Context, path string, obj any, headerKV ...
}

func (c *Client) postJSON(ctx context.Context, path string, obj any, headerKV ...string) (*http.Request, *http.Response, error) {
data, err := json.Marshal(obj)
if err != nil {
return nil, nil, err
var body io.Reader

switch v := obj.(type) {
case string:
if v != "" {
body = strings.NewReader(v)
}
default:
data, err := json.Marshal(obj)
if err != nil {
return nil, nil, err
}
body = bytes.NewBuffer(data)
headerKV = append(headerKV, "Content-Type", "application/json")
}
return c.doRequest(ctx, http.MethodPost, path, bytes.NewBuffer(data), append(headerKV, "Content-Type", "application/json")...)
return c.doRequest(ctx, http.MethodPost, path, body, headerKV...)
}

func (c *Client) doStream(ctx context.Context, method, path string, body io.Reader, headerKV ...string) (*http.Request, *http.Response, error) {
Expand Down
139 changes: 139 additions & 0 deletions apiclient/task.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package apiclient

import (
"context"
"fmt"
"net/http"
"sort"

"github.com/acorn-io/acorn/apiclient/types"
)

type ListTasksOptions struct {
ThreadID string
AssistantID string
}

func (c *Client) ListTasks(ctx context.Context, opts ListTasksOptions) (result types.TaskList, err error) {
defer func() {
sort.Slice(result.Items, func(i, j int) bool {
return result.Items[i].Metadata.Created.Time.Before(result.Items[j].Metadata.Created.Time)
})
}()

if opts.ThreadID == "" && opts.AssistantID == "" {
return result, fmt.Errorf("either threadID or assistantID must be provided")
}
var url string
if opts.ThreadID != "" {
url = fmt.Sprintf("/threads/%s/tasks", opts.ThreadID)
} else {
url = fmt.Sprintf("/assistants/%s/tasks", opts.AssistantID)
}
_, resp, err := c.doRequest(ctx, http.MethodGet, url, nil)
if err != nil {
return result, err
}
defer resp.Body.Close()

_, err = toObject(resp, &result)
return
}

type UpdateTaskOptions struct {
ThreadID string
AssistantID string
}

func (c *Client) UpdateTask(ctx context.Context, id string, manifest types.TaskManifest, opt UpdateTaskOptions) (*types.Task, error) {
var url string
if opt.ThreadID != "" {
url = fmt.Sprintf("/threads/%s/tasks/%s", opt.ThreadID, id)
} else {
url = fmt.Sprintf("/assistants/%s/tasks/%s", opt.AssistantID, id)
}

_, resp, err := c.putJSON(ctx, url, manifest)
if err != nil {
return nil, err
}
defer resp.Body.Close()

return toObject(resp, &types.Task{})
}

type CreateTaskOptions struct {
ThreadID string
AssistantID string
}

func (c *Client) CreateTask(ctx context.Context, manifest types.TaskManifest, opt CreateTaskOptions) (*types.Task, error) {
var url string
if opt.ThreadID != "" {
url = fmt.Sprintf("/threads/%s/tasks", opt.ThreadID)
} else {
url = fmt.Sprintf("/assistants/%s/tasks", opt.AssistantID)
}

_, resp, err := c.postJSON(ctx, url, manifest)
if err != nil {
return nil, err
}
defer resp.Body.Close()

return toObject(resp, &types.Task{})
}

type ListTaskRunsOptions struct {
ThreadID string
AssistantID string
}

func (c *Client) ListTaskRuns(ctx context.Context, taskID string, opts ListTaskRunsOptions) (result types.TaskRunList, err error) {
defer func() {
sort.Slice(result.Items, func(i, j int) bool {
return result.Items[i].Metadata.Created.Time.Before(result.Items[j].Metadata.Created.Time)
})
}()

if opts.ThreadID == "" && opts.AssistantID == "" {
return result, fmt.Errorf("either threadID or assistantID must be provided")
}
var url string
if opts.ThreadID != "" {
url = fmt.Sprintf("/threads/%s/tasks/%s/runs", opts.ThreadID, taskID)
} else {
url = fmt.Sprintf("/assistants/%s/tasks/%s/runs", opts.AssistantID, taskID)
}

_, resp, err := c.doRequest(ctx, http.MethodGet, url, nil)
if err != nil {
return
}
defer resp.Body.Close()

_, err = toObject(resp, &result)
return
}

type TaskRunOptions struct {
ThreadID string
AssistantID string
}

func (c *Client) RunTask(ctx context.Context, taskID string, input string, opts TaskRunOptions) (*types.TaskRun, error) {
var url string
if opts.ThreadID != "" {
url = fmt.Sprintf("/threads/%s/tasks/%s/runs", opts.ThreadID, taskID)
} else {
url = fmt.Sprintf("/assistants/%s/tasks/%s/runs", opts.AssistantID, taskID)
}

_, resp, err := c.postJSON(ctx, url, input)
if err != nil {
return nil, err
}
defer resp.Body.Close()

return toObject(resp, &types.TaskRun{})
}
2 changes: 1 addition & 1 deletion pkg/api/authz/thread.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func authorizeThread(req *http.Request, user user.Info) bool {
if req.Method == "GET" && strings.HasPrefix(req.URL.Path, "/api/threads/"+thread+"/") {
return true
}
if req.Method == "POST" && req.URL.Path == "/api/invoke/"+agent+"/"+"threads/"+thread {
if req.Method == "POST" && strings.HasPrefix(req.URL.Path, "/api/threads/"+thread+"/tasks/") {
return true
}

Expand Down
14 changes: 3 additions & 11 deletions pkg/api/handlers/assistants.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,7 @@ func (a *AssistantHandler) Events(req api.Context) error {
}

func (a *AssistantHandler) Files(req api.Context) error {
var (
id = req.PathValue("id")
)

thread, err := getUserThread(req, id)
thread, err := getThreadForScope(req)
if err != nil {
return err
}
Expand All @@ -258,17 +254,13 @@ func (a *AssistantHandler) Files(req api.Context) error {
}

func (a *AssistantHandler) GetFile(req api.Context) error {
var (
id = req.PathValue("id")
)

thread, err := getUserThread(req, id)
thread, err := getThreadForScope(req)
if err != nil {
return err
}

if thread.Status.WorkspaceID == "" {
return types.NewErrNotFound("no workspace found for assistant %s", id)
return types.NewErrNotFound("no workspace found")
}

return getFileInWorkspace(req.Context(), req, a.gptScript, thread.Status.WorkspaceID, "files/")
Expand Down
62 changes: 48 additions & 14 deletions pkg/api/handlers/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -506,15 +506,13 @@ func (t *TaskHandler) updateCron(req api.Context, workflow *v1.Workflow, task ty
}

func (t *TaskHandler) getAssistantThreadAndManifestFromRequest(req api.Context) (*v1.Agent, *v1.Thread, types.WorkflowManifest, types.TaskManifest, error) {
assistantID := req.PathValue("assistant_id")

assistant, err := getAssistant(req, assistantID)
thread, err := getThreadForScope(req)
if err != nil {
return nil, nil, types.WorkflowManifest{}, types.TaskManifest{}, err
}

thread, err := getUserThread(req, assistantID)
if err != nil {
var agent v1.Agent
if err := req.Get(&agent, thread.Spec.AgentName); err != nil {
return nil, nil, types.WorkflowManifest{}, types.TaskManifest{}, err
}

Expand All @@ -523,7 +521,7 @@ func (t *TaskHandler) getAssistantThreadAndManifestFromRequest(req api.Context)
return nil, nil, types.WorkflowManifest{}, types.TaskManifest{}, err
}

return assistant, thread, toWorkflowManifest(assistant, thread, manifest), manifest, nil
return &agent, thread, toWorkflowManifest(&agent, thread, manifest), manifest, nil
}

func (t *TaskHandler) Create(req api.Context) error {
Expand Down Expand Up @@ -659,15 +657,13 @@ func (t *TaskHandler) Get(req api.Context) error {
}

func (t *TaskHandler) getTask(req api.Context) (*v1.Workflow, *v1.Thread, error) {
assistantID := req.PathValue("assistant_id")

var workflow v1.Workflow
if err := req.Get(&workflow, req.PathValue("id")); err != nil {
thread, err := getThreadForScope(req)
if err != nil {
return nil, nil, err
}

thread, err := getUserThread(req, assistantID)
if err != nil {
var workflow v1.Workflow
if err := req.Get(&workflow, req.PathValue("id")); err != nil {
return nil, nil, err
}

Expand All @@ -678,10 +674,48 @@ func (t *TaskHandler) getTask(req api.Context) (*v1.Workflow, *v1.Thread, error)
return &workflow, thread, nil
}

func (t *TaskHandler) List(req api.Context) error {
func getThreadForScope(req api.Context) (*v1.Thread, error) {
assistantID := req.PathValue("assistant_id")

thread, err := getUserThread(req, assistantID)
if assistantID != "" {
thread, err := getUserThread(req, assistantID)
if err != nil {
return nil, err
}

taskID := req.PathValue("task_id")
runID := req.PathValue("run_id")
if taskID != "" && runID != "" {
if runID == "editor" {
runID = editorWFE(req, taskID)
}
var wfe v1.WorkflowExecution
if err := req.Get(&wfe, runID); err != nil {
return nil, err
}
if wfe.Spec.ThreadName != thread.Name {
return nil, types.NewErrHttp(http.StatusForbidden, "task run does not belong to the thread")
}
if wfe.Spec.WorkflowName != taskID {
return nil, types.NewErrNotFound("task run not found")
}
return thread, req.Get(thread, wfe.Status.ThreadName)
}

return thread, nil
}

threadID := req.PathValue("thread_id")

var thread v1.Thread
if err := req.Get(&thread, threadID); err != nil {
return nil, err
}
return &thread, nil
}

func (t *TaskHandler) List(req api.Context) error {
thread, err := getThreadForScope(req)
if err != nil {
return err
}
Expand Down
11 changes: 10 additions & 1 deletion pkg/api/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func Router(services *services.Services) (http.Handler, error) {
mux.HandleFunc("DELETE /api/assistants/{id}/tools/{tool}", assistants.RemoveTool)
mux.HandleFunc("PUT /api/assistants/{id}/tools/{tool}", assistants.AddTool)
// Assistant files
mux.HandleFunc("GET /api/assistants/{id}/files", assistants.Files)
mux.HandleFunc("GET /api/assistants/{assistant_id}/files", assistants.Files)
mux.HandleFunc("GET /api/assistants/{id}/file/{file...}", assistants.GetFile)
mux.HandleFunc("POST /api/assistants/{id}/files/{file...}", assistants.UploadFile)
mux.HandleFunc("DELETE /api/assistants/{id}/files/{file...}", assistants.DeleteFile)
Expand All @@ -67,15 +67,24 @@ func Router(services *services.Services) (http.Handler, error) {

// Tasks
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks", tasks.List)
mux.HandleFunc("GET /api/threads/{thread_id}/tasks", tasks.List)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}", tasks.Get)
mux.HandleFunc("GET /api/threads/{thread_id}/tasks/{id}", tasks.Get)
mux.HandleFunc("POST /api/assistants/{assistant_id}/tasks", tasks.Create)
mux.HandleFunc("POST /api/threads/{thread_id}/tasks", tasks.Create)
mux.HandleFunc("PUT /api/assistants/{assistant_id}/tasks/{id}", tasks.Update)
mux.HandleFunc("PUT /api/threads/{thread_id}/tasks/{id}", tasks.Update)
mux.HandleFunc("DELETE /api/assistants/{assistant_id}/tasks/{id}", tasks.Delete)
mux.HandleFunc("POST /api/assistants/{assistant_id}/tasks/{id}/run", tasks.Run)
mux.HandleFunc("POST /api/threads/{thread_id}/tasks/{id}/run", tasks.Run)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}/runs", tasks.ListRuns)
mux.HandleFunc("GET /api/threads/{thread_id}/tasks/{id}/runs", tasks.ListRuns)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}/runs/{run_id}", tasks.GetRun)
mux.HandleFunc("GET /api/threads/{thread_id}/tasks/{id}/runs/{run_id}", tasks.GetRun)
mux.HandleFunc("POST /api/assistants/{assistant_id}/tasks/{id}/runs/{run_id}/abort", tasks.AbortRun)
mux.HandleFunc("DELETE /api/assistants/{assistant_id}/tasks/{id}/runs/{run_id}", tasks.DeleteRun)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{task_id}/runs/{run_id}/files", assistants.Files)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{task_id}/runs/{run_id}/file/{file...}", assistants.GetFile)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}/events", tasks.Events)
mux.HandleFunc("POST /api/assistants/{assistant_id}/tasks/{id}/events", tasks.Abort)
mux.HandleFunc("GET /api/assistants/{assistant_id}/tasks/{id}/runs/{run_id}/events", tasks.Events)
Expand Down
6 changes: 5 additions & 1 deletion ui/user/src/lib/actions/div.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import { tick } from 'svelte';

export function autoscroll(node: HTMLElement) {
const observer = new MutationObserver(() => {
if (node.dataset.scroll !== 'false') {
node.scrollTop = node.scrollHeight;
tick().then(() => {
node.scrollTop = node.scrollHeight;
});
}
});

Expand Down
2 changes: 1 addition & 1 deletion ui/user/src/lib/components/Editors.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
}}
/>
{:else if file.name.toLowerCase().endsWith('.png')}
<Image mime="image/png" {file} />
<Image {file} />
{:else}
<Codemirror {file} {onFileChanged} {onInvoke} />
{/if}
Expand Down
7 changes: 3 additions & 4 deletions ui/user/src/lib/components/editor/Image.svelte
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<script lang="ts" >
<script lang="ts">
import type { EditorItem } from '$lib/stores/editor.svelte';
interface Props {
Expand All @@ -20,9 +20,8 @@
return toDataURL(file.blob);
}
});
</script>

{#await src then src}
<img src={src} alt="AI generated, content unknown"/>
{/await}
<img {src} alt="AI generated, content unknown" />
{/await}
2 changes: 1 addition & 1 deletion ui/user/src/lib/components/messages/Input.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
};
for (const file of editor) {
if (file.modified) {
if (file.modified && !file.taskID) {
if (!input.changedFiles) {
input.changedFiles = {};
}
Expand Down
Loading

0 comments on commit 8fa4f94

Please sign in to comment.