Skip to content

Commit c795fc8

Browse files
feat(azure-openai): allow usage of azure-openai for knowledgebase uploads and wand generation (#1056)
* feat(azure-openai): allow usage of azure-openai for knowledgebase uploads
* feat(azure-openai): added azure-openai for kb and wand
* added embeddings utils, added the ability to use mistral through Azure
* fix(oauth): gdrive picker race condition, token route cleanup
* fix test
* feat(mailer): consolidated all emailing to mailer service, added support for Azure ACS (#1054)
* fix batch invitation email template
* cleanup
* improvement(emails): add help template instead of doing it inline
* remove fallback version

---------

Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
1 parent cea42f5 commit c795fc8

8 files changed

Lines changed: 851 additions & 470 deletions

File tree

apps/sim/app/api/knowledge/search/utils.test.ts

Lines changed: 285 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,50 @@
44
*
55
* @vitest-environment node
66
*/
7-
import { describe, expect, it, vi } from 'vitest'
7+
import { beforeEach, describe, expect, it, vi } from 'vitest'
88

99
vi.mock('drizzle-orm')
10-
vi.mock('@/lib/logs/console/logger')
10+
vi.mock('@/lib/logs/console/logger', () => ({
11+
createLogger: vi.fn(() => ({
12+
info: vi.fn(),
13+
debug: vi.fn(),
14+
warn: vi.fn(),
15+
error: vi.fn(),
16+
})),
17+
}))
1118
vi.mock('@/db')
19+
vi.mock('@/lib/documents/utils', () => ({
20+
retryWithExponentialBackoff: (fn: any) => fn(),
21+
}))
1222

13-
import { handleTagAndVectorSearch, handleTagOnlySearch, handleVectorOnlySearch } from './utils'
23+
vi.stubGlobal(
24+
'fetch',
25+
vi.fn().mockResolvedValue({
26+
ok: true,
27+
json: async () => ({
28+
data: [{ embedding: [0.1, 0.2, 0.3] }],
29+
}),
30+
})
31+
)
32+
33+
vi.mock('@/lib/env', () => ({
34+
env: {},
35+
isTruthy: (value: string | boolean | number | undefined) =>
36+
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
37+
}))
38+
39+
import {
40+
generateSearchEmbedding,
41+
handleTagAndVectorSearch,
42+
handleTagOnlySearch,
43+
handleVectorOnlySearch,
44+
} from './utils'
1445

1546
describe('Knowledge Search Utils', () => {
47+
beforeEach(() => {
48+
vi.clearAllMocks()
49+
})
50+
1651
describe('handleTagOnlySearch', () => {
1752
it('should throw error when no filters provided', async () => {
1853
const params = {
@@ -140,4 +175,251 @@ describe('Knowledge Search Utils', () => {
140175
expect(params.distanceThreshold).toBe(0.8)
141176
})
142177
})
178+
179+
describe('generateSearchEmbedding', () => {
180+
it('should use Azure OpenAI when KB-specific config is provided', async () => {
181+
const { env } = await import('@/lib/env')
182+
Object.keys(env).forEach((key) => delete (env as any)[key])
183+
Object.assign(env, {
184+
AZURE_OPENAI_API_KEY: 'test-azure-key',
185+
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
186+
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
187+
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
188+
OPENAI_API_KEY: 'test-openai-key',
189+
})
190+
191+
const fetchSpy = vi.mocked(fetch)
192+
fetchSpy.mockResolvedValueOnce({
193+
ok: true,
194+
json: async () => ({
195+
data: [{ embedding: [0.1, 0.2, 0.3] }],
196+
}),
197+
} as any)
198+
199+
const result = await generateSearchEmbedding('test query')
200+
201+
expect(fetchSpy).toHaveBeenCalledWith(
202+
'https://test.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-12-01-preview',
203+
expect.objectContaining({
204+
headers: expect.objectContaining({
205+
'api-key': 'test-azure-key',
206+
}),
207+
})
208+
)
209+
expect(result).toEqual([0.1, 0.2, 0.3])
210+
211+
// Clean up
212+
Object.keys(env).forEach((key) => delete (env as any)[key])
213+
})
214+
215+
it('should fallback to OpenAI when no KB Azure config provided', async () => {
216+
const { env } = await import('@/lib/env')
217+
Object.keys(env).forEach((key) => delete (env as any)[key])
218+
Object.assign(env, {
219+
OPENAI_API_KEY: 'test-openai-key',
220+
})
221+
222+
const fetchSpy = vi.mocked(fetch)
223+
fetchSpy.mockResolvedValueOnce({
224+
ok: true,
225+
json: async () => ({
226+
data: [{ embedding: [0.1, 0.2, 0.3] }],
227+
}),
228+
} as any)
229+
230+
const result = await generateSearchEmbedding('test query')
231+
232+
expect(fetchSpy).toHaveBeenCalledWith(
233+
'https://api.openai.com/v1/embeddings',
234+
expect.objectContaining({
235+
headers: expect.objectContaining({
236+
Authorization: 'Bearer test-openai-key',
237+
}),
238+
})
239+
)
240+
expect(result).toEqual([0.1, 0.2, 0.3])
241+
242+
// Clean up
243+
Object.keys(env).forEach((key) => delete (env as any)[key])
244+
})
245+
246+
it('should use default API version when not provided in Azure config', async () => {
247+
const { env } = await import('@/lib/env')
248+
Object.keys(env).forEach((key) => delete (env as any)[key])
249+
Object.assign(env, {
250+
AZURE_OPENAI_API_KEY: 'test-azure-key',
251+
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
252+
KB_OPENAI_MODEL_NAME: 'custom-embedding-model',
253+
OPENAI_API_KEY: 'test-openai-key',
254+
})
255+
256+
const fetchSpy = vi.mocked(fetch)
257+
fetchSpy.mockResolvedValueOnce({
258+
ok: true,
259+
json: async () => ({
260+
data: [{ embedding: [0.1, 0.2, 0.3] }],
261+
}),
262+
} as any)
263+
264+
await generateSearchEmbedding('test query')
265+
266+
expect(fetchSpy).toHaveBeenCalledWith(
267+
expect.stringContaining('api-version='),
268+
expect.any(Object)
269+
)
270+
271+
// Clean up
272+
Object.keys(env).forEach((key) => delete (env as any)[key])
273+
})
274+
275+
it('should use custom model name when provided in Azure config', async () => {
276+
const { env } = await import('@/lib/env')
277+
Object.keys(env).forEach((key) => delete (env as any)[key])
278+
Object.assign(env, {
279+
AZURE_OPENAI_API_KEY: 'test-azure-key',
280+
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
281+
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
282+
KB_OPENAI_MODEL_NAME: 'custom-embedding-model',
283+
OPENAI_API_KEY: 'test-openai-key',
284+
})
285+
286+
const fetchSpy = vi.mocked(fetch)
287+
fetchSpy.mockResolvedValueOnce({
288+
ok: true,
289+
json: async () => ({
290+
data: [{ embedding: [0.1, 0.2, 0.3] }],
291+
}),
292+
} as any)
293+
294+
await generateSearchEmbedding('test query', 'text-embedding-3-small')
295+
296+
expect(fetchSpy).toHaveBeenCalledWith(
297+
'https://test.openai.azure.com/openai/deployments/custom-embedding-model/embeddings?api-version=2024-12-01-preview',
298+
expect.any(Object)
299+
)
300+
301+
// Clean up
302+
Object.keys(env).forEach((key) => delete (env as any)[key])
303+
})
304+
305+
it('should throw error when no API configuration provided', async () => {
306+
const { env } = await import('@/lib/env')
307+
Object.keys(env).forEach((key) => delete (env as any)[key])
308+
309+
await expect(generateSearchEmbedding('test query')).rejects.toThrow(
310+
'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured'
311+
)
312+
})
313+
314+
it('should handle Azure OpenAI API errors properly', async () => {
315+
const { env } = await import('@/lib/env')
316+
Object.keys(env).forEach((key) => delete (env as any)[key])
317+
Object.assign(env, {
318+
AZURE_OPENAI_API_KEY: 'test-azure-key',
319+
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
320+
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
321+
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
322+
})
323+
324+
const fetchSpy = vi.mocked(fetch)
325+
fetchSpy.mockResolvedValueOnce({
326+
ok: false,
327+
status: 404,
328+
statusText: 'Not Found',
329+
text: async () => 'Deployment not found',
330+
} as any)
331+
332+
await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')
333+
334+
// Clean up
335+
Object.keys(env).forEach((key) => delete (env as any)[key])
336+
})
337+
338+
it('should handle OpenAI API errors properly', async () => {
339+
const { env } = await import('@/lib/env')
340+
Object.keys(env).forEach((key) => delete (env as any)[key])
341+
Object.assign(env, {
342+
OPENAI_API_KEY: 'test-openai-key',
343+
})
344+
345+
const fetchSpy = vi.mocked(fetch)
346+
fetchSpy.mockResolvedValueOnce({
347+
ok: false,
348+
status: 429,
349+
statusText: 'Too Many Requests',
350+
text: async () => 'Rate limit exceeded',
351+
} as any)
352+
353+
await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')
354+
355+
// Clean up
356+
Object.keys(env).forEach((key) => delete (env as any)[key])
357+
})
358+
359+
it('should include correct request body for Azure OpenAI', async () => {
360+
const { env } = await import('@/lib/env')
361+
Object.keys(env).forEach((key) => delete (env as any)[key])
362+
Object.assign(env, {
363+
AZURE_OPENAI_API_KEY: 'test-azure-key',
364+
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
365+
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
366+
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
367+
})
368+
369+
const fetchSpy = vi.mocked(fetch)
370+
fetchSpy.mockResolvedValueOnce({
371+
ok: true,
372+
json: async () => ({
373+
data: [{ embedding: [0.1, 0.2, 0.3] }],
374+
}),
375+
} as any)
376+
377+
await generateSearchEmbedding('test query')
378+
379+
expect(fetchSpy).toHaveBeenCalledWith(
380+
expect.any(String),
381+
expect.objectContaining({
382+
body: JSON.stringify({
383+
input: ['test query'],
384+
encoding_format: 'float',
385+
}),
386+
})
387+
)
388+
389+
// Clean up
390+
Object.keys(env).forEach((key) => delete (env as any)[key])
391+
})
392+
393+
it('should include correct request body for OpenAI', async () => {
394+
const { env } = await import('@/lib/env')
395+
Object.keys(env).forEach((key) => delete (env as any)[key])
396+
Object.assign(env, {
397+
OPENAI_API_KEY: 'test-openai-key',
398+
})
399+
400+
const fetchSpy = vi.mocked(fetch)
401+
fetchSpy.mockResolvedValueOnce({
402+
ok: true,
403+
json: async () => ({
404+
data: [{ embedding: [0.1, 0.2, 0.3] }],
405+
}),
406+
} as any)
407+
408+
await generateSearchEmbedding('test query', 'text-embedding-3-small')
409+
410+
expect(fetchSpy).toHaveBeenCalledWith(
411+
expect.any(String),
412+
expect.objectContaining({
413+
body: JSON.stringify({
414+
input: ['test query'],
415+
model: 'text-embedding-3-small',
416+
encoding_format: 'float',
417+
}),
418+
})
419+
)
420+
421+
// Clean up
422+
Object.keys(env).forEach((key) => delete (env as any)[key])
423+
})
424+
})
143425
})

apps/sim/app/api/knowledge/search/utils.ts

Lines changed: 2 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,10 @@
11
import { and, eq, inArray, sql } from 'drizzle-orm'
2-
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
3-
import { env } from '@/lib/env'
42
import { createLogger } from '@/lib/logs/console/logger'
53
import { db } from '@/db'
64
import { embedding } from '@/db/schema'
75

86
const logger = createLogger('KnowledgeSearchUtils')
97

10-
export class APIError extends Error {
11-
public status: number
12-
13-
constructor(message: string, status: number) {
14-
super(message)
15-
this.name = 'APIError'
16-
this.status = status
17-
}
18-
}
19-
208
export interface SearchResult {
219
id: string
2210
content: string
@@ -41,61 +29,8 @@ export interface SearchParams {
4129
distanceThreshold?: number
4230
}
4331

44-
export async function generateSearchEmbedding(query: string): Promise<number[]> {
45-
const openaiApiKey = env.OPENAI_API_KEY
46-
if (!openaiApiKey) {
47-
throw new Error('OPENAI_API_KEY not configured')
48-
}
49-
50-
try {
51-
const embedding = await retryWithExponentialBackoff(
52-
async () => {
53-
const response = await fetch('https://api.openai.com/v1/embeddings', {
54-
method: 'POST',
55-
headers: {
56-
Authorization: `Bearer ${openaiApiKey}`,
57-
'Content-Type': 'application/json',
58-
},
59-
body: JSON.stringify({
60-
input: query,
61-
model: 'text-embedding-3-small',
62-
encoding_format: 'float',
63-
}),
64-
})
65-
66-
if (!response.ok) {
67-
const errorText = await response.text()
68-
const error = new APIError(
69-
`OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`,
70-
response.status
71-
)
72-
throw error
73-
}
74-
75-
const data = await response.json()
76-
77-
if (!data.data || !Array.isArray(data.data) || data.data.length === 0) {
78-
throw new Error('Invalid response format from OpenAI embeddings API')
79-
}
80-
81-
return data.data[0].embedding
82-
},
83-
{
84-
maxRetries: 5,
85-
initialDelayMs: 1000,
86-
maxDelayMs: 30000,
87-
backoffMultiplier: 2,
88-
}
89-
)
90-
91-
return embedding
92-
} catch (error) {
93-
logger.error('Failed to generate search embedding:', error)
94-
throw new Error(
95-
`Embedding generation failed: ${error instanceof Error ? error.message : 'Unknown error'}`
96-
)
97-
}
98-
}
32+
// Use shared embedding utility
33+
export { generateSearchEmbedding } from '@/lib/embeddings/utils'
9934

10035
function getTagFilters(filters: Record<string, string>, embedding: any) {
10136
return Object.entries(filters).map(([key, value]) => {

0 commit comments

Comments
 (0)