fix(server): abort behavior in sse stream

This commit is contained in:
DarkSky 2025-05-09 19:22:29 +08:00
parent e00a37cd00
commit 0b03868450
3 changed files with 139 additions and 74 deletions

View File

@ -49,7 +49,7 @@ import { COPILOT_LOCKER, CopilotType } from '../resolver';
import { ChatSessionService } from '../session';
import { CopilotStorage } from '../storage';
import { MAX_EMBEDDABLE_SIZE } from '../types';
import { readStream } from '../utils';
import { getSignal, readStream } from '../utils';
import { CopilotContextService } from './service';
@InputType()
@ -391,16 +391,6 @@ export class CopilotContextResolver {
private readonly storage: CopilotStorage
) {}
private getSignal(req: Request) {
const controller = new AbortController();
req.socket.on('close', hasError => {
if (hasError) {
controller.abort();
}
});
return controller.signal;
}
@ResolveField(() => [CopilotContextCategory], {
description: 'list collections in context',
})
@ -716,7 +706,7 @@ export class CopilotContextResolver {
context.workspaceId,
content,
limit,
this.getSignal(ctx.req),
getSignal(ctx.req).signal,
threshold
);
}
@ -725,7 +715,7 @@ export class CopilotContextResolver {
return await session.matchFiles(
content,
limit,
this.getSignal(ctx.req),
getSignal(ctx.req).signal,
scopedThreshold,
threshold
);
@ -791,7 +781,7 @@ export class CopilotContextResolver {
context.workspaceId,
content,
limit,
this.getSignal(ctx.req),
getSignal(ctx.req).signal,
threshold
);
}
@ -808,7 +798,7 @@ export class CopilotContextResolver {
const chunks = await session.matchWorkspaceDocs(
content,
limit,
this.getSignal(ctx.req),
getSignal(ctx.req).signal,
scopedThreshold,
threshold
);

View File

@ -13,22 +13,22 @@ import type { Request, Response } from 'express';
import {
BehaviorSubject,
catchError,
concatMap,
connect,
EMPTY,
filter,
finalize,
from,
ignoreElements,
interval,
lastValueFrom,
map,
merge,
mergeMap,
Observable,
reduce,
Subject,
take,
takeUntil,
toArray,
tap,
} from 'rxjs';
import {
@ -50,11 +50,13 @@ import {
CopilotProviderFactory,
ModelInputType,
ModelOutputType,
StreamObject,
} from './providers';
import { StreamObjectParser } from './providers/utils';
import { ChatSession, ChatSessionService } from './session';
import { CopilotStorage } from './storage';
import { ChatMessage, ChatQuerySchema } from './types';
import { getSignal } from './utils';
import { CopilotWorkflowService, GraphExecutorState } from './workflow';
export interface ChatEvent {
@ -156,16 +158,6 @@ export class CopilotController implements BeforeApplicationShutdown {
return [latestMessage, session];
}
private getSignal(req: Request) {
const controller = new AbortController();
req.socket.on('close', hasError => {
if (hasError) {
controller.abort();
}
});
return controller.signal;
}
private parseNumber(value: string | string[] | undefined) {
if (!value) {
return undefined;
@ -255,7 +247,7 @@ export class CopilotController implements BeforeApplicationShutdown {
const { reasoning, webSearch } = ChatQuerySchema.parse(query);
const content = await provider.text({ modelId: model }, finalMessage, {
...session.config.promptConfig,
signal: this.getSignal(req),
signal: getSignal(req).signal,
user: user.id,
workspace: session.config.workspaceId,
reasoning,
@ -305,11 +297,13 @@ export class CopilotController implements BeforeApplicationShutdown {
metrics.ai.counter('chat_stream_calls').add(1, { model });
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { signal, onConnectionClosed } = getSignal(req);
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
const source$ = from(
provider.streamText({ modelId: model }, finalMessage, {
...session.config.promptConfig,
signal: this.getSignal(req),
signal,
user: user.id,
workspace: session.config.workspaceId,
reasoning,
@ -324,16 +318,25 @@ export class CopilotController implements BeforeApplicationShutdown {
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values.join(''),
createdAt: new Date(),
reduce((acc, chunk) => acc + chunk, ''),
tap(buffer => {
onConnectionClosed(isAborted => {
session.push({
role: 'assistant',
content: isAborted ? '> Request aborted' : buffer,
createdAt: new Date(),
});
void session
.save()
.catch(err =>
this.logger.error(
'Failed to save session in sse stream',
err
)
);
});
return from(session.save());
}),
mergeMap(() => EMPTY)
ignoreElements()
)
)
),
@ -378,11 +381,13 @@ export class CopilotController implements BeforeApplicationShutdown {
metrics.ai.counter('chat_object_stream_calls').add(1, { model });
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { signal, onConnectionClosed } = getSignal(req);
const { messageId, reasoning, webSearch } = ChatQuerySchema.parse(query);
const source$ = from(
provider.streamObject({ modelId: model }, finalMessage, {
...session.config.promptConfig,
signal: this.getSignal(req),
signal,
user: user.id,
workspace: session.config.workspaceId,
reasoning,
@ -397,20 +402,29 @@ export class CopilotController implements BeforeApplicationShutdown {
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
const parser = new StreamObjectParser();
const streamObjects = parser.mergeTextDelta(values);
const content = parser.mergeContent(streamObjects);
session.push({
role: 'assistant',
content,
streamObjects,
createdAt: new Date(),
reduce((acc, chunk) => acc.concat([chunk]), [] as StreamObject[]),
tap(result => {
onConnectionClosed(isAborted => {
const parser = new StreamObjectParser();
const streamObjects = parser.mergeTextDelta(result);
const content = parser.mergeContent(streamObjects);
session.push({
role: 'assistant',
content: isAborted ? '> Request aborted' : content,
streamObjects: isAborted ? null : streamObjects,
createdAt: new Date(),
});
void session
.save()
.catch(err =>
this.logger.error(
'Failed to save session in sse stream',
err
)
);
});
return from(session.save());
}),
mergeMap(() => EMPTY)
ignoreElements()
)
)
),
@ -458,10 +472,12 @@ export class CopilotController implements BeforeApplicationShutdown {
});
}
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { signal, onConnectionClosed } = getSignal(req);
const source$ = from(
this.workflow.runGraph(params, session.model, {
...session.config.promptConfig,
signal: this.getSignal(req),
signal,
user: user.id,
workspace: session.config.workspaceId,
})
@ -499,19 +515,30 @@ export class CopilotController implements BeforeApplicationShutdown {
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(values => {
session.push({
role: 'assistant',
content: values
.filter(v => v.status === GraphExecutorState.EmitContent)
.map(v => v.content)
.join(''),
createdAt: new Date(),
reduce((acc, chunk) => {
if (chunk.status === GraphExecutorState.EmitContent) {
acc += chunk.content;
}
return acc;
}, ''),
tap(content => {
onConnectionClosed(isAborted => {
session.push({
role: 'assistant',
content: isAborted ? '> Request aborted' : content,
createdAt: new Date(),
});
void session
.save()
.catch(err =>
this.logger.error(
'Failed to save session in sse stream',
err
)
);
});
return from(session.save());
}),
mergeMap(() => EMPTY)
ignoreElements()
)
)
),
@ -571,6 +598,8 @@ export class CopilotController implements BeforeApplicationShutdown {
sessionId
);
this.ongoingStreamCount$.next(this.ongoingStreamCount$.value + 1);
const { signal, onConnectionClosed } = getSignal(req);
const source$ = from(
provider.streamImages(
{
@ -584,7 +613,7 @@ export class CopilotController implements BeforeApplicationShutdown {
...session.config.promptConfig,
quality: params.quality || undefined,
seed: this.parseNumber(params.seed),
signal: this.getSignal(req),
signal,
user: user.id,
workspace: session.config.workspaceId,
}
@ -603,17 +632,26 @@ export class CopilotController implements BeforeApplicationShutdown {
),
// save the generated text to the session
shared$.pipe(
toArray(),
concatMap(attachments => {
session.push({
role: 'assistant',
content: '',
attachments: attachments,
createdAt: new Date(),
reduce((acc, chunk) => acc.concat([chunk]), [] as string[]),
tap(attachments => {
onConnectionClosed(isAborted => {
session.push({
role: 'assistant',
content: isAborted ? '> Request aborted' : '',
attachments: isAborted ? [] : attachments,
createdAt: new Date(),
});
void session
.save()
.catch(err =>
this.logger.error(
'Failed to save session in sse stream',
err
)
);
});
return from(session.save());
}),
mergeMap(() => EMPTY)
ignoreElements()
)
)
),
@ -651,7 +689,7 @@ export class CopilotController implements BeforeApplicationShutdown {
`https://api.unsplash.com/search/photos?${query}`,
{
headers: { Authorization: `Client-ID ${key}` },
signal: this.getSignal(req),
signal: getSignal(req).signal,
}
);

View File

@ -1,5 +1,7 @@
import { Readable } from 'node:stream';
import type { Request } from 'express';
import { readBufferWithLimit } from '../../base';
import { MAX_EMBEDDABLE_SIZE } from './types';
@ -9,3 +11,38 @@ export function readStream(
): Promise<Buffer> {
return readBufferWithLimit(readable, maxSize);
}
type RequestClosedCallback = (isAborted: boolean) => void;
type SignalReturnType = {
signal: AbortSignal;
onConnectionClosed: (cb: RequestClosedCallback) => void;
};
export function getSignal(req: Request): SignalReturnType {
const controller = new AbortController();
let isAborted = true;
let callback: ((isAborted: boolean) => void) | undefined = undefined;
const onSocketEnd = () => {
isAborted = false;
};
const onSocketClose = (hadError: boolean) => {
req.socket.off('end', onSocketEnd);
req.socket.off('close', onSocketClose);
const aborted = hadError || isAborted;
if (aborted) {
controller.abort();
}
callback?.(aborted);
};
req.socket.on('end', onSocketEnd);
req.socket.on('close', onSocketClose);
return {
signal: controller.signal,
onConnectionClosed: cb => (callback = cb),
};
}