From 7b629d347b03955959f6944e40fce391fded9108 Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 30 Oct 2025 22:33:18 +0000 Subject: [PATCH 1/4] feat(coderd/database): add GetTaskByOwnerIDAndName --- coderd/database/dbauthz/dbauthz.go | 4 +++ coderd/database/dbauthz/dbauthz_test.go | 11 +++++++ coderd/database/dbmetrics/querymetrics.go | 7 +++++ coderd/database/dbmock/dbmock.go | 15 ++++++++++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 36 +++++++++++++++++++++++ coderd/database/queries/tasks.sql | 5 ++++ 7 files changed, 79 insertions(+) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 8066ebd0479a1..87b5de36009bf 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2989,6 +2989,10 @@ func (q *querier) GetTaskByID(ctx context.Context, id uuid.UUID) (database.Task, return fetch(q.log, q.auth, q.db.GetTaskByID)(ctx, id) } +func (q *querier) GetTaskByOwnerIDAndName(ctx context.Context, arg database.GetTaskByOwnerIDAndNameParams) (database.Task, error) { + return fetch(q.log, q.auth, q.db.GetTaskByOwnerIDAndName)(ctx, arg) +} + func (q *querier) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.Task, error) { return fetch(q.log, q.auth, q.db.GetTaskByWorkspaceID)(ctx, workspaceID) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 32c951fb5c20b..7d7c136eb543e 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2375,6 +2375,17 @@ func (s *MethodTestSuite) TestTasks() { dbm.EXPECT().GetTaskByID(gomock.Any(), task.ID).Return(task, nil).AnyTimes() check.Args(task.ID).Asserts(task, policy.ActionRead).Returns(task) })) + s.Run("GetTaskByOwnerIDAndName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + task := testutil.Fake(s.T(), faker, database.Task{}) + dbm.EXPECT().GetTaskByOwnerIDAndName(gomock.Any(), database.GetTaskByOwnerIDAndNameParams{ + OwnerID: task.OwnerID, + Name: task.Name, + }).Return(task, nil).AnyTimes() + check.Args(database.GetTaskByOwnerIDAndNameParams{ + OwnerID: task.OwnerID, + Name: task.Name, + }).Asserts(task, policy.ActionRead).Returns(task) + })) s.Run("DeleteTask", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { task := testutil.Fake(s.T(), faker, database.Task{}) arg := database.DeleteTaskParams{ diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 252f6f9b5ad09..d841315924a15 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1530,6 +1530,13 @@ func (m queryMetricsStore) GetTaskByID(ctx context.Context, id uuid.UUID) (datab return r0, r1 } +func (m queryMetricsStore) GetTaskByOwnerIDAndName(ctx context.Context, arg database.GetTaskByOwnerIDAndNameParams) (database.Task, error) { + start := time.Now() + r0, r1 := m.s.GetTaskByOwnerIDAndName(ctx, arg) + m.queryLatencies.WithLabelValues("GetTaskByOwnerIDAndName").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.Task, error) { start := time.Now() r0, r1 := m.s.GetTaskByWorkspaceID(ctx, workspaceID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index af89a987a3203..313bb988979a1 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -3237,6 +3237,21 @@ func (mr *MockStoreMockRecorder) GetTaskByID(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByID", reflect.TypeOf((*MockStore)(nil).GetTaskByID), ctx, id) } +// GetTaskByOwnerIDAndName mocks base method. +func (m *MockStore) GetTaskByOwnerIDAndName(ctx context.Context, arg database.GetTaskByOwnerIDAndNameParams) (database.Task, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTaskByOwnerIDAndName", ctx, arg) + ret0, _ := ret[0].(database.Task) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetTaskByOwnerIDAndName indicates an expected call of GetTaskByOwnerIDAndName. +func (mr *MockStoreMockRecorder) GetTaskByOwnerIDAndName(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByOwnerIDAndName", reflect.TypeOf((*MockStore)(nil).GetTaskByOwnerIDAndName), ctx, arg) +} + // GetTaskByWorkspaceID mocks base method. func (m *MockStore) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (database.Task, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 2739cb7430d9f..3e5771f96de04 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -343,6 +343,7 @@ type sqlcQuerier interface { GetTailnetTunnelPeerBindings(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerBindingsRow, error) GetTailnetTunnelPeerIDs(ctx context.Context, srcID uuid.UUID) ([]GetTailnetTunnelPeerIDsRow, error) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error) + GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error) GetTaskByWorkspaceID(ctx context.Context, workspaceID uuid.UUID) (Task, error) GetTelemetryItem(ctx context.Context, key string) (TelemetryItem, error) GetTelemetryItems(ctx context.Context) ([]TelemetryItem, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 65fac4733b100..3bb36cc9036e5 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -13093,6 +13093,42 @@ func (q *sqlQuerier) GetTaskByID(ctx context.Context, id uuid.UUID) (Task, error return i, err } +const getTaskByOwnerIDAndName = `-- name: GetTaskByOwnerIDAndName :one +SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status +WHERE owner_id = $1::uuid + AND LOWER(name) = LOWER($2::text) +` + +type GetTaskByOwnerIDAndNameParams struct { + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + Name string `db:"name" json:"name"` +} + +func (q *sqlQuerier) GetTaskByOwnerIDAndName(ctx context.Context, arg GetTaskByOwnerIDAndNameParams) (Task, error) { + row := q.db.QueryRowContext(ctx, getTaskByOwnerIDAndName, arg.OwnerID, arg.Name) + var i Task + err := row.Scan( + &i.ID, + &i.OrganizationID, + &i.OwnerID, + &i.Name, + &i.WorkspaceID, + &i.TemplateVersionID, + &i.TemplateParameters, + &i.Prompt, + &i.CreatedAt, + &i.DeletedAt, + &i.Status, + &i.WorkspaceBuildNumber, + &i.WorkspaceAgentID, + &i.WorkspaceAppID, + &i.OwnerUsername, + &i.OwnerName, + &i.OwnerAvatarUrl, + ) + return i, err +} + const getTaskByWorkspaceID = `-- name: GetTaskByWorkspaceID :one SELECT id, organization_id, owner_id, name, workspace_id, template_version_id, template_parameters, prompt, created_at, deleted_at, status, workspace_build_number, workspace_agent_id, workspace_app_id, owner_username, owner_name, owner_avatar_url FROM tasks_with_status WHERE workspace_id = $1::uuid ` diff --git a/coderd/database/queries/tasks.sql b/coderd/database/queries/tasks.sql index d0617ad39f4dc..dc7f4503140a9 100644 --- a/coderd/database/queries/tasks.sql +++ b/coderd/database/queries/tasks.sql @@ -41,6 +41,11 @@ SELECT * FROM tasks_with_status WHERE id = @id::uuid; -- name: GetTaskByWorkspaceID :one SELECT * FROM tasks_with_status WHERE workspace_id = @workspace_id::uuid; +-- name: GetTaskByOwnerIDAndName :one +SELECT * FROM tasks_with_status +WHERE owner_id = @owner_id::uuid + AND LOWER(name) = LOWER(@name::text); + -- name: ListTasks :many SELECT * FROM tasks_with_status tws WHERE tws.deleted_at IS NULL From ac251e40410d7f8a9bd06f5042bfd7c5a85952bd Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Thu, 30 Oct 2025 22:33:50 +0000 Subject: [PATCH 2/4] chore: add check for authz in TestTasks/Get --- coderd/aitasks_test.go | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/coderd/aitasks_test.go b/coderd/aitasks_test.go index 34f6dd4a0798c..cf75f9a0d24c4 100644 --- a/coderd/aitasks_test.go +++ b/coderd/aitasks_test.go @@ -156,12 +156,13 @@ func TestTasks(t *testing.T) { t.Parallel() var ( - client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) - ctx = testutil.Context(t, testutil.WaitLong) - user = coderdtest.CreateFirstUser(t, client) - template = createAITemplate(t, client, user) - wantPrompt = "review my code" - exp = codersdk.NewExperimentalClient(client) + client, db = coderdtest.NewWithDatabase(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) + ctx = testutil.Context(t, testutil.WaitLong) + user = coderdtest.CreateFirstUser(t, client) + anotherUser, _ = coderdtest.CreateAnotherUser(t, client, user.OrganizationID) + template = createAITemplate(t, client, user) + wantPrompt = "review my code" + exp = codersdk.NewExperimentalClient(client) ) task, err := exp.CreateTask(ctx, "me", codersdk.CreateTaskRequest{ @@ -211,6 +212,13 @@ func TestTasks(t *testing.T) { assert.Equal(t, taskAppID, updated.WorkspaceAppID.UUID, "workspace app id should match") assert.NotEmpty(t, updated.WorkspaceStatus, "task status should not be empty") + // Another member user should not be able to fetch the task + _, err = codersdk.NewExperimentalClient(anotherUser).TaskByID(ctx, task.ID) + require.Error(t, err, "fetching task should fail for another member user") + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + // Stop the workspace coderdtest.MustTransitionWorkspace(t, client, task.WorkspaceID.UUID, codersdk.WorkspaceTransitionStart, codersdk.WorkspaceTransitionStop) From 6f171b9739bf33ade86beada609e2b46792c33fa Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Mon, 3 Nov 2025 10:15:19 +0000 Subject: [PATCH 3/4] feat(coderd/httpmw): support lookup by task name in TaskParam middlware --- coderd/httpmw/taskparam.go | 116 ++++++++++++-- coderd/httpmw/taskparam_test.go | 262 +++++++++++++++++++++++++------- 2 files changed, 311 insertions(+), 67 deletions(-) diff --git a/coderd/httpmw/taskparam.go b/coderd/httpmw/taskparam.go index 6ecc888b378fe..481cfd7ca20db 100644 --- a/coderd/httpmw/taskparam.go +++ b/coderd/httpmw/taskparam.go @@ -2,8 +2,13 @@ package httpmw import ( "context" + "database/sql" "net/http" + "github.com/go-chi/chi/v5" + "github.com/google/uuid" + "golang.org/x/xerrors" + "cdr.dev/slog" "github.com/coder/coder/v2/coderd/database" @@ -23,32 +28,119 @@ func TaskParam(r *http.Request) database.Task { return task } -// ExtractTaskParam grabs a task from the "task" URL parameter by UUID. +func fetchTaskByUUID(ctx context.Context, db database.Store, taskParam string, _ uuid.UUID) (database.Task, error) { + taskID, err := uuid.Parse(taskParam) + if err != nil { + // Not a valid UUID, skip this strategy + return database.Task{}, sql.ErrNoRows + } + task, err := db.GetTaskByID(ctx, taskID) + if err != nil { + return database.Task{}, xerrors.Errorf("fetch task by uuid: %w", err) + } + return task, nil +} + +func fetchTaskByName(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) { + task, err := db.GetTaskByOwnerIDAndName(ctx, database.GetTaskByOwnerIDAndNameParams{ + OwnerID: ownerID, + Name: taskParam, + }) + if err != nil { + return database.Task{}, xerrors.Errorf("fetch task by name: %w", err) + } + return task, nil +} + +func fetchTaskByWorkspaceName(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) { + workspace, err := db.GetWorkspaceByOwnerIDAndName(ctx, database.GetWorkspaceByOwnerIDAndNameParams{ + OwnerID: ownerID, + Name: taskParam, + }) + if err != nil { + return database.Task{}, xerrors.Errorf("fetch workspace by name: %w", err) + } + // Check if workspace has an associated task before querying + if !workspace.TaskID.Valid { + return database.Task{}, sql.ErrNoRows + } + task, err := db.GetTaskByWorkspaceID(ctx, workspace.ID) + if err != nil { + return database.Task{}, xerrors.Errorf("fetch task by workspace id: %w", err) + } + return task, nil +} + +type taskFetchFunc func(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) + +var ( + taskLookupStrategyNames = []string{"uuid", "name", "workspace"} + taskLookupStrategyFuncs = []taskFetchFunc{fetchTaskByUUID, fetchTaskByName, fetchTaskByWorkspaceName} +) + +// ExtractTaskParam grabs a task from the "task" URL parameter. +// It supports three lookup strategies with cascading fallback: +// 1. Task UUID (primary) +// 2. Task name scoped to owner (secondary) +// 3. Workspace name scoped to owner (tertiary, for legacy links) +// +// This middleware depends on ExtractOrganizationMembersParam being in the chain +// to provide the owner context for name-based lookups. func ExtractTaskParam(db database.Store) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - taskID, parsed := ParseUUIDParam(rw, r, "task") - if !parsed { + + // Get the task parameter value. We can't use ParseUUIDParam here because + // we need to support non-UUID values (task names and workspace names) and + // attempt all lookup strategies. + taskParam := chi.URLParam(r, "task") + if taskParam == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "\"task\" must be provided.", + }) return } - task, err := db.GetTaskByID(ctx, taskID) - if err != nil { - if httpapi.Is404Error(err) { - httpapi.ResourceNotFound(rw) + + // Get owner from OrganizationMembersParam middleware for name-based lookups + members := OrganizationMembersParam(r) + ownerID := members.UserID() + + // Try each strategy in order until one succeeds + var task database.Task + var foundBy string + for i, fetch := range taskLookupStrategyFuncs { + t, err := fetch(ctx, db, taskParam, ownerID) + if err == nil { + task = t + foundBy = taskLookupStrategyNames[i] + break + } + if !httpapi.Is404Error(err) { + // Real error (not just "not found") + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching task.", + Detail: err.Error(), + }) return } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Internal error fetching task.", - Detail: err.Error(), - }) + // Continue to next strategy on 404 + } + + // If no strategy succeeded, return 404 + if foundBy == "" { + httpapi.ResourceNotFound(rw) return } ctx = context.WithValue(ctx, taskParamContextKey{}, task) if rlogger := loggermw.RequestLoggerFromContext(ctx); rlogger != nil { - rlogger.WithFields(slog.F("task_id", task.ID), slog.F("task_name", task.Name)) + rlogger.WithFields( + slog.F("task_id", task.ID), + slog.F("task_name", task.Name), + slog.F("task_lookup_strategy", foundBy), + ) } next.ServeHTTP(rw, r.WithContext(ctx)) diff --git a/coderd/httpmw/taskparam_test.go b/coderd/httpmw/taskparam_test.go index 559ccc2a2df2d..0640044ed6b7c 100644 --- a/coderd/httpmw/taskparam_test.go +++ b/coderd/httpmw/taskparam_test.go @@ -14,25 +14,119 @@ import ( "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/httpmw" + "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/codersdk" ) func TestTaskParam(t *testing.T) { t.Parallel() - setup := func(db database.Store) (*http.Request, database.User) { - user := dbgen.User(t, db, database.User{}) - _, token := dbgen.APIKey(t, db, database.APIKey{ - UserID: user.ID, - }) + // Create all fixtures once - they're only read, never modified + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + _, token := dbgen.APIKey(t, db, database.APIKey{ + UserID: user.ID, + }) + org := dbgen.Organization(t, db, database.Organization{}) + tpl := dbgen.Template(t, db, database.Template{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + TemplateID: uuid.NullUUID{ + UUID: tpl.ID, + Valid: true, + }, + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: tpl.ID, + }) + task := dbgen.Task(t, db, database.TaskTable{ + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, + Prompt: "test prompt", + }) + workspaceNoTask := dbgen.Workspace(t, db, database.WorkspaceTable{ + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: tpl.ID, + }) + // Create fixtures with specific names for certain tests + taskMixedCase := dbgen.Task(t, db, database.TaskTable{ + Name: "MyTask", + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, + Prompt: "test prompt", + }) + taskFoundByUUID := dbgen.Task(t, db, database.TaskTable{ + Name: "found-by-uuid", + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, + Prompt: "test prompt", + }) + // To test precedence of UUID over name, we create another task with the same name as the UUID task + _ = dbgen.Task(t, db, database.TaskTable{ + Name: taskFoundByUUID.ID.String(), + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, + Prompt: "test prompt", + }) + workspaceSharedName := dbgen.Workspace(t, db, database.WorkspaceTable{ + Name: "shared-name", + OwnerID: user.ID, + OrganizationID: org.ID, + TemplateID: tpl.ID, + }) + // We create a task with the same name as the workspace shared name. + _ = dbgen.Task(t, db, database.TaskTable{ + Name: "task-different-name", + OrganizationID: org.ID, + OwnerID: user.ID, + TemplateVersionID: tv.ID, + WorkspaceID: uuid.NullUUID{UUID: workspaceSharedName.ID, Valid: true}, + Prompt: "test prompt", + }) + makeRequest := func(userID uuid.UUID, sessionToken string) *http.Request { r := httptest.NewRequest("GET", "/", nil) - r.Header.Set(codersdk.SessionTokenHeader, token) + r.Header.Set(codersdk.SessionTokenHeader, sessionToken) ctx := chi.NewRouteContext() - ctx.URLParams.Add("user", "me") + ctx.URLParams.Add("user", userID.String()) r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, ctx)) - return r, user + return r + } + + makeRouter := func() chi.Router { + rtr := chi.NewRouter() + rtr.Use( + httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ + DB: db, + RedirectToLogin: false, + }), + httpmw.ExtractOrganizationMembersParam(db, func(r *http.Request, _ policy.Action, _ rbac.Objecter) bool { + return true + }), + httpmw.ExtractTaskParam(db), + ) + rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { + _ = httpmw.TaskParam(r) + rw.WriteHeader(http.StatusOK) + }) + return rtr } t.Run("None", func(t *testing.T) { @@ -41,7 +135,8 @@ func TestTaskParam(t *testing.T) { rtr := chi.NewRouter() rtr.Use(httpmw.ExtractTaskParam(db)) rtr.Get("/", nil) - r, _ := setup(db) + r := httptest.NewRequest("GET", "/", nil) + r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, chi.NewRouteContext())) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -52,11 +147,8 @@ func TestTaskParam(t *testing.T) { t.Run("NotFound", func(t *testing.T) { t.Parallel() - db, _ := dbtestutil.NewDB(t) - rtr := chi.NewRouter() - rtr.Use(httpmw.ExtractTaskParam(db)) - rtr.Get("/", nil) - r, _ := setup(db) + rtr := makeRouter() + r := makeRequest(user.ID, token) chi.RouteContext(r.Context()).URLParams.Add("task", uuid.NewString()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -68,47 +160,8 @@ func TestTaskParam(t *testing.T) { t.Run("Found", func(t *testing.T) { t.Parallel() - db, _ := dbtestutil.NewDB(t) - rtr := chi.NewRouter() - rtr.Use( - httpmw.ExtractAPIKeyMW(httpmw.ExtractAPIKeyConfig{ - DB: db, - RedirectToLogin: false, - }), - httpmw.ExtractTaskParam(db), - ) - rtr.Get("/", func(rw http.ResponseWriter, r *http.Request) { - _ = httpmw.TaskParam(r) - rw.WriteHeader(http.StatusOK) - }) - r, user := setup(db) - org := dbgen.Organization(t, db, database.Organization{}) - tpl := dbgen.Template(t, db, database.Template{ - OrganizationID: org.ID, - CreatedBy: user.ID, - }) - tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ - TemplateID: uuid.NullUUID{ - UUID: tpl.ID, - Valid: true, - }, - OrganizationID: org.ID, - CreatedBy: user.ID, - }) - workspace := dbgen.Workspace(t, db, database.WorkspaceTable{ - OwnerID: user.ID, - Name: "test-workspace", - OrganizationID: org.ID, - TemplateID: tpl.ID, - }) - task := dbgen.Task(t, db, database.TaskTable{ - Name: "test-task", - OrganizationID: org.ID, - OwnerID: user.ID, - TemplateVersionID: tv.ID, - WorkspaceID: uuid.NullUUID{UUID: workspace.ID, Valid: true}, - Prompt: "test prompt", - }) + rtr := makeRouter() + r := makeRequest(user.ID, token) chi.RouteContext(r.Context()).URLParams.Add("task", task.ID.String()) rw := httptest.NewRecorder() rtr.ServeHTTP(rw, r) @@ -117,4 +170,103 @@ func TestTaskParam(t *testing.T) { defer res.Body.Close() require.Equal(t, http.StatusOK, res.StatusCode) }) + + t.Run("FoundByTaskName", func(t *testing.T) { + t.Parallel() + rtr := makeRouter() + r := makeRequest(user.ID, token) + chi.RouteContext(r.Context()).URLParams.Add("task", task.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("FoundByWorkspaceName", func(t *testing.T) { + t.Parallel() + rtr := makeRouter() + r := makeRequest(user.ID, token) + chi.RouteContext(r.Context()).URLParams.Add("task", workspace.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("CaseInsensitiveTaskName", func(t *testing.T) { + t.Parallel() + rtr := makeRouter() + r := makeRequest(user.ID, token) + // Look up with different case + chi.RouteContext(r.Context()).URLParams.Add("task", "mytask") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + // Verify we got the right task + require.Equal(t, "MyTask", taskMixedCase.Name) + }) + + t.Run("UUIDTakesPrecedence", func(t *testing.T) { + t.Parallel() + rtr := makeRouter() + r := makeRequest(user.ID, token) + // Look up by UUID - should find the first task, not the one named with the UUID + chi.RouteContext(r.Context()).URLParams.Add("task", taskFoundByUUID.ID.String()) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + // Verify we got the correct task + require.Equal(t, "found-by-uuid", taskFoundByUUID.Name) + }) + + t.Run("TaskNameNotFoundFallsBackToWorkspace", func(t *testing.T) { + t.Parallel() + rtr := makeRouter() + r := makeRequest(user.ID, token) + // Look up by workspace name (which is not a task name) - should fallback + chi.RouteContext(r.Context()).URLParams.Add("task", "shared-name") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusOK, res.StatusCode) + }) + + t.Run("NotFoundWhenNoMatch", func(t *testing.T) { + t.Parallel() + rtr := makeRouter() + r := makeRequest(user.ID, token) + chi.RouteContext(r.Context()).URLParams.Add("task", "nonexistent-name") + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) + + t.Run("WorkspaceWithoutTask", func(t *testing.T) { + t.Parallel() + rtr := makeRouter() + r := makeRequest(user.ID, token) + // Look up by workspace name, but workspace has no task + chi.RouteContext(r.Context()).URLParams.Add("task", workspaceNoTask.Name) + rw := httptest.NewRecorder() + rtr.ServeHTTP(rw, r) + + res := rw.Result() + defer res.Body.Close() + require.Equal(t, http.StatusNotFound, res.StatusCode) + }) } From 2ae85072d8b83dd418c0264d881fd2cb3005a22c Mon Sep 17 00:00:00 2001 From: Cian Johnston Date: Mon, 3 Nov 2025 10:36:01 +0000 Subject: [PATCH 4/4] remove lookup by workspace name, refactor and simplify --- coderd/httpmw/taskparam.go | 100 ++++++++++---------------------- coderd/httpmw/taskparam_test.go | 18 +----- 2 files changed, 32 insertions(+), 86 deletions(-) diff --git a/coderd/httpmw/taskparam.go b/coderd/httpmw/taskparam.go index 481cfd7ca20db..a24cb85412e63 100644 --- a/coderd/httpmw/taskparam.go +++ b/coderd/httpmw/taskparam.go @@ -3,6 +3,7 @@ package httpmw import ( "context" "database/sql" + "errors" "net/http" "github.com/go-chi/chi/v5" @@ -28,61 +29,10 @@ func TaskParam(r *http.Request) database.Task { return task } -func fetchTaskByUUID(ctx context.Context, db database.Store, taskParam string, _ uuid.UUID) (database.Task, error) { - taskID, err := uuid.Parse(taskParam) - if err != nil { - // Not a valid UUID, skip this strategy - return database.Task{}, sql.ErrNoRows - } - task, err := db.GetTaskByID(ctx, taskID) - if err != nil { - return database.Task{}, xerrors.Errorf("fetch task by uuid: %w", err) - } - return task, nil -} - -func fetchTaskByName(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) { - task, err := db.GetTaskByOwnerIDAndName(ctx, database.GetTaskByOwnerIDAndNameParams{ - OwnerID: ownerID, - Name: taskParam, - }) - if err != nil { - return database.Task{}, xerrors.Errorf("fetch task by name: %w", err) - } - return task, nil -} - -func fetchTaskByWorkspaceName(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) { - workspace, err := db.GetWorkspaceByOwnerIDAndName(ctx, database.GetWorkspaceByOwnerIDAndNameParams{ - OwnerID: ownerID, - Name: taskParam, - }) - if err != nil { - return database.Task{}, xerrors.Errorf("fetch workspace by name: %w", err) - } - // Check if workspace has an associated task before querying - if !workspace.TaskID.Valid { - return database.Task{}, sql.ErrNoRows - } - task, err := db.GetTaskByWorkspaceID(ctx, workspace.ID) - if err != nil { - return database.Task{}, xerrors.Errorf("fetch task by workspace id: %w", err) - } - return task, nil -} - -type taskFetchFunc func(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) - -var ( - taskLookupStrategyNames = []string{"uuid", "name", "workspace"} - taskLookupStrategyFuncs = []taskFetchFunc{fetchTaskByUUID, fetchTaskByName, fetchTaskByWorkspaceName} -) - // ExtractTaskParam grabs a task from the "task" URL parameter. -// It supports three lookup strategies with cascading fallback: +// It supports two lookup strategies: // 1. Task UUID (primary) // 2. Task name scoped to owner (secondary) -// 3. Workspace name scoped to owner (tertiary, for legacy links) // // This middleware depends on ExtractOrganizationMembersParam being in the chain // to provide the owner context for name-based lookups. @@ -92,7 +42,7 @@ func ExtractTaskParam(db database.Store) func(http.Handler) http.Handler { ctx := r.Context() // Get the task parameter value. We can't use ParseUUIDParam here because - // we need to support non-UUID values (task names and workspace names) and + // we need to support non-UUID values (task names) and // attempt all lookup strategies. taskParam := chi.URLParam(r, "task") if taskParam == "" { @@ -106,29 +56,15 @@ func ExtractTaskParam(db database.Store) func(http.Handler) http.Handler { members := OrganizationMembersParam(r) ownerID := members.UserID() - // Try each strategy in order until one succeeds - var task database.Task - var foundBy string - for i, fetch := range taskLookupStrategyFuncs { - t, err := fetch(ctx, db, taskParam, ownerID) - if err == nil { - task = t - foundBy = taskLookupStrategyNames[i] - break - } + task, err := fetchTaskWithFallback(ctx, db, taskParam, ownerID) + if err != nil { if !httpapi.Is404Error(err) { - // Real error (not just "not found") httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error fetching task.", Detail: err.Error(), }) return } - // Continue to next strategy on 404 - } - - // If no strategy succeeded, return 404 - if foundBy == "" { httpapi.ResourceNotFound(rw) return } @@ -139,7 +75,6 @@ func ExtractTaskParam(db database.Store) func(http.Handler) http.Handler { rlogger.WithFields( slog.F("task_id", task.ID), slog.F("task_name", task.Name), - slog.F("task_lookup_strategy", foundBy), ) } @@ -147,3 +82,28 @@ func ExtractTaskParam(db database.Store) func(http.Handler) http.Handler { }) } } + +func fetchTaskWithFallback(ctx context.Context, db database.Store, taskParam string, ownerID uuid.UUID) (database.Task, error) { + // Attempt to first lookup the task by UUID. + taskID, err := uuid.Parse(taskParam) + if err == nil { + task, err := db.GetTaskByID(ctx, taskID) + if err == nil { + return task, nil + } + // There may be a task named with a valid UUID. Fall back to name lookup in this case. + if !errors.Is(err, sql.ErrNoRows) { + return database.Task{}, xerrors.Errorf("fetch task by uuid: %w", err) + } + } + + // taskParam not a valid UUID, OR valid UUID but not found, so attempt lookup by name. + task, err := db.GetTaskByOwnerIDAndName(ctx, database.GetTaskByOwnerIDAndNameParams{ + OwnerID: ownerID, + Name: taskParam, + }) + if err != nil { + return database.Task{}, xerrors.Errorf("fetch task by name: %w", err) + } + return task, nil +} diff --git a/coderd/httpmw/taskparam_test.go b/coderd/httpmw/taskparam_test.go index 0640044ed6b7c..fe7eb88e40c4d 100644 --- a/coderd/httpmw/taskparam_test.go +++ b/coderd/httpmw/taskparam_test.go @@ -184,7 +184,7 @@ func TestTaskParam(t *testing.T) { require.Equal(t, http.StatusOK, res.StatusCode) }) - t.Run("FoundByWorkspaceName", func(t *testing.T) { + t.Run("NotFoundByWorkspaceName", func(t *testing.T) { t.Parallel() rtr := makeRouter() r := makeRequest(user.ID, token) @@ -194,7 +194,7 @@ func TestTaskParam(t *testing.T) { res := rw.Result() defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) + require.Equal(t, http.StatusNotFound, res.StatusCode) }) t.Run("CaseInsensitiveTaskName", func(t *testing.T) { @@ -229,20 +229,6 @@ func TestTaskParam(t *testing.T) { require.Equal(t, "found-by-uuid", taskFoundByUUID.Name) }) - t.Run("TaskNameNotFoundFallsBackToWorkspace", func(t *testing.T) { - t.Parallel() - rtr := makeRouter() - r := makeRequest(user.ID, token) - // Look up by workspace name (which is not a task name) - should fallback - chi.RouteContext(r.Context()).URLParams.Add("task", "shared-name") - rw := httptest.NewRecorder() - rtr.ServeHTTP(rw, r) - - res := rw.Result() - defer res.Body.Close() - require.Equal(t, http.StatusOK, res.StatusCode) - }) - t.Run("NotFoundWhenNoMatch", func(t *testing.T) { t.Parallel() rtr := makeRouter()