Skip to content

Commit ac9d73a

Browse files
committed
Use lepton for llama2
1 parent 1225017 commit ac9d73a

4 files changed

Lines changed: 134 additions & 112 deletions

File tree

src/app/bots/gradio/index.ts

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import WebSocketAsPromised from 'websocket-as-promised'
2+
import { ChatError, ErrorCode } from '~utils/errors'
3+
import { AbstractBot, SendMessageParams } from '../abstract-bot'
4+
import { html2md } from '~app/utils/markdown'
5+
6+
/**
 * Creates the random identifier that is sent as `session_hash` on the queue websocket.
 *
 * Inspired by https://stackoverflow.com/a/12502559/325241, but built one base-36
 * character at a time: `Math.random().toString(36).substring(2)` has a variable
 * length and yields an empty string when `Math.random()` returns exactly 0.
 *
 * @param length Number of characters to generate (defaults to 11).
 * @returns A string of `length` characters drawn from `0-9a-z`.
 */
function generateSessionHash(length = 11): string {
  let hash = ''
  for (let i = 0; i < length; i++) {
    // Math.random() is in [0, 1), so the digit is always in 0..35.
    hash += Math.floor(Math.random() * 36).toString(36)
  }
  return hash
}
10+
11+
/**
 * `fn_index` values placed in the queue messages sent by `GradioBot`.
 * NOTE(review): presumably these are function indices of the remote Gradio app — confirm per endpoint.
 */
enum FnIndex {
  /** Call that carries `[null, model, prompt]` as its data. */
  Send = 7,

  /** Call that carries `[null, ...params]`; only its events are turned into answers. */
  Receive = 8,
}
15+
16+
/** Per-conversation state that `GradioBot` keeps until the conversation is reset. */
interface ConversationContext {
  /** Identifier sent as `session_hash` in every queue message of the conversation. */
  sessionHash: string
}
19+
20+
/**
 * Bot that talks to a queue websocket endpoint (`wsUrl`) using the
 * `send_hash` / `send_data` / `process_generating` / `process_completed` message flow.
 *
 * Each message opens two sockets sharing one session hash: a "send" socket that
 * submits the model name and prompt, and a "receive" socket whose events are
 * translated into `UPDATE_ANSWER`, `ERROR` and `DONE` events for the caller.
 */
export class GradioBot extends AbstractBot {
  private conversationContext?: ConversationContext

  /**
   * @param wsUrl Websocket URL both sockets connect to.
   * @param model Model name sent together with the prompt.
   * @param params Numeric values sent, after a leading `null`, as the data of the receive call.
   * @param mode When `'html'`, answer text is passed through `html2md` before it is emitted.
   */
  constructor(public wsUrl: string, public model: string, public params: number[], public mode?: 'text' | 'html') {
    super()
  }

  /**
   * Sends `params.prompt` and streams the answer through `params.onEvent`.
   * A session is created lazily on the first message and reused afterwards.
   *
   * @throws ChatError with `NETWORK_ERROR` when a websocket cannot be opened.
   */
  async doSendMessage(params: SendMessageParams) {
    if (!this.conversationContext) {
      this.conversationContext = { sessionHash: await this.createSession(params.signal) }
    }
    const { sessionHash } = this.conversationContext

    const sendWsp = await this.connectWebsocket(
      FnIndex.Send,
      sessionHash,
      [null, this.model, params.prompt],
      params.onEvent,
    )

    let receiveWsp: WebSocketAsPromised
    try {
      receiveWsp = await this.connectWebsocket(FnIndex.Receive, sessionHash, [null, ...this.params], params.onEvent)
    } catch (err) {
      // The send socket is already open; close it so it does not leak.
      sendWsp.removeAllListeners()
      sendWsp.close()
      throw err
    }

    const closeSockets = () => {
      for (const wsp of [sendWsp, receiveWsp]) {
        wsp.removeAllListeners()
        wsp.close()
      }
    }

    if (params.signal?.aborted) {
      // The abort happened while connecting; the 'abort' event will not fire again.
      closeSockets()
      return
    }
    // `once` lets the listener be released after it has run.
    params.signal?.addEventListener('abort', closeSockets, { once: true })
  }

  /**
   * Opens one websocket and wires up the queue message handling for it.
   *
   * @param fnIndex Value sent as `fn_index`; answer and completion handling only runs for `FnIndex.Receive`.
   * @param sessionHash Value sent as `session_hash`.
   * @param data Payload sent when the server asks for it with `send_data`.
   * @param onEvent Callback receiving `UPDATE_ANSWER`, `ERROR` and `DONE` events.
   * @returns The opened websocket.
   * @throws ChatError with `NETWORK_ERROR` when the socket cannot be opened.
   */
  async connectWebsocket(fnIndex: number, sessionHash: string, data: unknown[], onEvent: SendMessageParams['onEvent']) {
    const wsp = new WebSocketAsPromised(this.wsUrl, {
      packMessage: (data) => JSON.stringify(data),
      unpackMessage: (data) => JSON.parse(data as string),
    })

    wsp.onUnpackedMessage.addListener(async (event) => {
      if (event.msg === 'send_hash') {
        wsp.sendPacked({ fn_index: fnIndex, session_hash: sessionHash })
      } else if (event.msg === 'send_data') {
        wsp.sendPacked({
          fn_index: fnIndex,
          data,
          event_data: null,
          session_hash: sessionHash,
        })
      } else if (event.msg === 'process_generating') {
        if (event.success && event.output?.data) {
          if (fnIndex === FnIndex.Receive) {
            // data[1] is read as a list of rows; the answer is the second cell of the last row.
            const history = event.output.data[1]
            const text = history?.[history.length - 1]?.[1]
            // Skip rows whose answer cell is not text yet instead of emitting or converting it.
            if (typeof text === 'string') {
              onEvent({
                type: 'UPDATE_ANSWER',
                data: {
                  text: this.mode === 'html' ? html2md(text) : text,
                },
              })
            }
          }
        } else {
          onEvent({
            type: 'ERROR',
            error: new ChatError(event.output?.error ?? 'Unknown error', ErrorCode.UNKOWN_ERROR),
          })
        }
      } else if (event.msg === 'queue_full') {
        onEvent({ type: 'ERROR', error: new ChatError('queue_full', ErrorCode.UNKOWN_ERROR) })
      } else if (event.msg === 'process_completed' && fnIndex === FnIndex.Receive) {
        const history = event.output?.data?.[1]
        if (history == null) {
          // No output data at all: report the failure rather than
          // throwing a TypeError inside this listener.
          onEvent({
            type: 'ERROR',
            error: new ChatError(event.output?.error ?? 'Unknown error', ErrorCode.UNKOWN_ERROR),
          })
        } else if (!history.length) {
          onEvent({
            type: 'ERROR',
            error: new ChatError('Session has been inactive for too long', ErrorCode.LMSYS_SESSION_EXPIRED),
          })
        }
      }
    })

    if (fnIndex === FnIndex.Receive) {
      // The receive socket closing marks the end of the answer.
      wsp.onClose.addListener(() => {
        wsp.removeAllListeners()
        onEvent({ type: 'DONE' })
      })
    }

    try {
      await wsp.open()
    } catch (err) {
      console.error('gradio ws open error', this.wsUrl, err)
      throw new ChatError('Failed to establish websocket connection.', ErrorCode.NETWORK_ERROR)
    }

    return wsp
  }

  /** Forgets the current session so the next message creates a new one. */
  resetConversation() {
    this.conversationContext = undefined
  }

  /**
   * Creates the session hash for a new conversation.
   *
   * @param _signal Unused here.
   * @returns A freshly generated session hash.
   */
  public async createSession(_signal?: AbortSignal) {
    return generateSessionHash()
  }
}

src/app/bots/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { BardBot } from './bard'
22
import { BingWebBot } from './bing'
33
import { ChatGPTBot } from './chatgpt'
44
import { ClaudeBot } from './claude'
5+
import { GradioBot } from './gradio'
56
import { LMSYSBot } from './lmsys'
67
import { PiBot } from './pi'
78
import { XunfeiBot } from './xunfei'
@@ -41,7 +42,7 @@ export function createBotInstance(botId: BotId) {
4142
case 'chatglm':
4243
return new LMSYSBot('chatglm2-6b')
4344
case 'llama':
44-
return new LMSYSBot('llama-2-13b-chat')
45+
return new GradioBot('wss://llama2.lepton.run/chat/queue/join', 'llama2', [0.5, 0.8, 512], 'html')
4546
case 'oasst':
4647
return new LMSYSBot('oasst-pythia-12b')
4748
case 'rwkv':

src/app/bots/lmsys/index.ts

Lines changed: 11 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,114 +1,18 @@
11
import WebSocketAsPromised from 'websocket-as-promised'
2-
import { ChatError, ErrorCode } from '~utils/errors'
3-
import { AbstractBot, SendMessageParams } from '../abstract-bot'
4-
import { generateSessionHash } from './utils'
5-
6-
enum FnIndex {
7-
Send = 7,
8-
Receive = 8,
9-
}
10-
11-
interface ConversationContext {
12-
sessionHash: string
13-
}
14-
15-
export class LMSYSBot extends AbstractBot {
16-
public model: string
17-
private conversationContext?: ConversationContext
2+
import { GradioBot } from '../gradio'
183

4+
export class LMSYSBot extends GradioBot {
195
constructor(model: string) {
20-
super()
21-
this.model = model
22-
}
23-
24-
async doSendMessage(params: SendMessageParams) {
25-
if (!this.conversationContext) {
26-
const sessionHash = await this.createSession(params.signal)
27-
this.conversationContext = { sessionHash }
28-
}
29-
30-
const sendWsp = await this.connectWebsocket(
31-
FnIndex.Send,
32-
this.conversationContext.sessionHash,
33-
[null, this.model, params.prompt],
34-
params.onEvent,
35-
)
36-
const receiveWsp = await this.connectWebsocket(
37-
FnIndex.Receive,
38-
this.conversationContext.sessionHash,
39-
[null, 0.7, 1, 512],
40-
params.onEvent,
41-
)
42-
43-
params.signal?.addEventListener('abort', () => {
44-
;[sendWsp, receiveWsp].forEach((wsp) => {
45-
wsp.removeAllListeners()
46-
wsp.close()
47-
})
48-
})
49-
}
50-
51-
async connectWebsocket(fnIndex: number, sessionHash: string, data: unknown[], onEvent: SendMessageParams['onEvent']) {
52-
const wsp = new WebSocketAsPromised('wss://chat.lmsys.org/queue/join', {
53-
packMessage: (data) => JSON.stringify(data),
54-
unpackMessage: (data) => JSON.parse(data as string),
55-
})
56-
57-
wsp.onUnpackedMessage.addListener(async (event) => {
58-
if (event.msg === 'send_hash') {
59-
wsp.sendPacked({ fn_index: fnIndex, session_hash: sessionHash })
60-
} else if (event.msg === 'send_data') {
61-
wsp.sendPacked({
62-
fn_index: fnIndex,
63-
data,
64-
event_data: null,
65-
session_hash: sessionHash,
66-
})
67-
} else if (event.msg === 'process_generating') {
68-
if (event.success && event.output.data) {
69-
if (fnIndex === FnIndex.Receive) {
70-
const outputData = event.output.data
71-
if (outputData[1].length > 0) {
72-
const text = outputData[1][outputData[1].length - 1][1]
73-
onEvent({ type: 'UPDATE_ANSWER', data: { text } })
74-
}
75-
}
76-
} else {
77-
onEvent({ type: 'ERROR', error: new ChatError(event.output.error, ErrorCode.UNKOWN_ERROR) })
78-
}
79-
} else if (event.msg === 'queue_full') {
80-
onEvent({ type: 'ERROR', error: new ChatError('queue_full', ErrorCode.UNKOWN_ERROR) })
81-
} else if (event.msg === 'process_completed' && fnIndex === FnIndex.Receive && !event.output.data[1].length) {
82-
onEvent({
83-
type: 'ERROR',
84-
error: new ChatError('Session has been inactive for too long', ErrorCode.LMSYS_SESSION_EXPIRED),
85-
})
86-
}
87-
})
88-
89-
if (fnIndex === FnIndex.Receive) {
90-
wsp.onClose.addListener(() => {
91-
wsp.removeAllListeners()
92-
onEvent({ type: 'DONE' })
93-
})
94-
}
95-
96-
try {
97-
await wsp.open()
98-
} catch (err) {
99-
console.error('lmsys ws open error', err)
100-
throw new ChatError('Failed to establish websocket connection.', ErrorCode.NETWORK_ERROR)
101-
}
102-
103-
return wsp
104-
}
105-
106-
resetConversation() {
107-
this.conversationContext = undefined
6+
super('wss://chat.lmsys.org/queue/join', model, [0.7, 1, 512], 'text')
1087
}
1098

110-
async initializeSession(fnIndex: number, sessionHash: string, data: unknown[], signal?: AbortSignal): Promise<void> {
111-
const wsp = new WebSocketAsPromised('wss://chat.lmsys.org/queue/join', {
9+
private async initializeSession(
10+
fnIndex: number,
11+
sessionHash: string,
12+
data: unknown[],
13+
signal?: AbortSignal,
14+
): Promise<void> {
15+
const wsp = new WebSocketAsPromised(this.wsUrl, {
11216
packMessage: (data) => JSON.stringify(data),
11317
unpackMessage: (data) => JSON.parse(data as string),
11418
})
@@ -128,7 +32,7 @@ export class LMSYSBot extends AbstractBot {
12832
}
12933

13034
async createSession(signal?: AbortSignal) {
131-
const sessionHash = generateSessionHash()
35+
const sessionHash = await super.createSession(signal)
13236
await Promise.all([
13337
this.initializeSession(36, sessionHash, [], signal),
13438
this.initializeSession(43, sessionHash, [{}], signal),

src/app/bots/lmsys/utils.ts

Lines changed: 0 additions & 4 deletions
This file was deleted.

0 commit comments

Comments
 (0)