Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api-docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -4778,6 +4778,10 @@
"description": "OpenAI settings",
"type": "string"
},
"api_url": {
"description": "Custom transcription API base URL (OpenAI adapter only)",
"type": "string"
},
"attention_context_left": {
"description": "NVIDIA Parakeet-specific parameters for long-form audio",
"type": "integer"
Expand Down Expand Up @@ -4930,6 +4934,10 @@
"threads": {
"type": "integer"
},
"timeout_minutes": {
"description": "HTTP request timeout in minutes (OpenAI adapter; applies to both the official API and custom base URLs)",
"type": "integer"
},
"vad_method": {
"description": "VAD (Voice Activity Detection) settings",
"type": "string"
Expand Down
6 changes: 6 additions & 0 deletions api-docs/swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,9 @@ definitions:
api_key:
description: OpenAI settings
type: string
api_url:
description: Custom transcription API base URL (OpenAI adapter only)
type: string
attention_context_left:
description: NVIDIA Parakeet-specific parameters for long-form audio
type: integer
Expand Down Expand Up @@ -747,6 +750,9 @@ definitions:
type: number
threads:
type: integer
timeout_minutes:
description: HTTP request timeout in minutes (OpenAI adapter; applies to both the official API and custom base URLs)
type: integer
vad_method:
description: VAD (Voice Activity Detection) settings
type: string
Expand Down
2 changes: 1 addition & 1 deletion internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func Load() *Config {
TempDir: getEnv("TEMP_DIR", "data/temp"),
WhisperXEnv: getEnv("WHISPERX_ENV", "data/whisperx-env"),
SecureCookies: getEnv("SECURE_COOKIES", defaultSecure) == "true",
OpenAIAPIKey: getEnv("OPENAI_API_KEY", ""),
OpenAIAPIKey: getEnv("OPENAI_API_KEY", ""),
HFToken: getEnv("HF_TOKEN", ""),
}
}
Expand Down
4 changes: 3 additions & 1 deletion internal/models/transcription.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,9 @@ type WhisperXParams struct {
CallbackURL *string `json:"callback_url,omitempty" gorm:"type:text"`

// OpenAI settings
APIKey *string `json:"api_key,omitempty" gorm:"type:text"`
APIKey *string `json:"api_key,omitempty" gorm:"type:text"`
APIURL *string `json:"api_url,omitempty" gorm:"type:text"`
TimeoutMinutes *int `json:"timeout_minutes,omitempty" gorm:"type:int"`

// Voxtral settings
MaxNewTokens *int `json:"max_new_tokens,omitempty" gorm:"type:int"`
Expand Down
12 changes: 11 additions & 1 deletion internal/transcription/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ err := adapter.ValidateParameters(params)
| `whisperx` | `whisper` | 90+ languages | Timestamps, Diarization, Translation |
| `parakeet` | `nvidia_parakeet` | English only | Timestamps, Long-form, High Quality |
| `canary` | `nvidia_canary` | 12 languages | Timestamps, Translation, Multilingual |
| `openai_whisper` | `openai` | 57 languages | Timestamps, Diarization, Translation, Custom Endpoint |

### Diarization Models

Expand Down Expand Up @@ -221,9 +222,18 @@ params := map[string]interface{}{
// NVIDIA Canary with translation
params := map[string]interface{}{
"source_lang": "es",
"target_lang": "en",
"target_lang": "en",
"task": "translate",
}

// OpenAI with custom self-hosted endpoint
params := map[string]interface{}{
"base_url": "http://localhost:8000/v1",
"model": "Systran/faster-whisper-large-v3",
"timeout_minutes": 30,
"diarize": true,
"diarize_model": "pyannote",
}
```

## Testing
Expand Down
50 changes: 38 additions & 12 deletions internal/transcription/adapters/openai_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Features: map[string]bool{
"timestamps": true, // Verbose JSON response includes segments
"word_level": false, // Not supported by standard API yet (unless using verbose_json with timestamp_granularities which is beta)
"diarization": false, // Not supported by OpenAI API
"diarization": true, // Post-processing via pyannote/sortformer pipeline
"translation": true,
"language_detection": true,
"vad": true, // Implicit
Expand All @@ -59,13 +59,19 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Description: "OpenAI API Key (overrides system default)",
Group: "authentication",
},
{
Name: "base_url",
Type: "string",
Required: false,
Description: "Custom transcription API base URL (overrides server default)",
Group: "authentication",
},
{
Name: "model",
Type: "string",
Required: false,
Default: "whisper-1",
Options: []string{"whisper-1"},
Description: "ID of the model to use",
Description: "Model name (e.g. whisper-1, or any model exposed by a custom endpoint)",
Group: "basic",
},
{
Expand All @@ -92,6 +98,15 @@ func NewOpenAIAdapter(apiKey string) *OpenAIAdapter {
Description: "Sampling temperature",
Group: "quality",
},
{
Name: "timeout_minutes",
Type: "int",
Required: false,
Default: 10,
Min: &[]float64{1}[0],
Description: "HTTP request timeout in minutes (increase for large files on self-hosted endpoints)",
Group: "advanced",
},
}

baseAdapter := NewBaseAdapter("openai_whisper", "", capabilities, schema)
Expand Down Expand Up @@ -153,7 +168,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
apiKey = key
}

if apiKey == "" {
const officialURL = "https://api.openai.com/v1/audio/transcriptions"
endpointURL := officialURL
if url := a.GetStringParameter(params, "base_url"); url != "" {
endpointURL = strings.TrimRight(url, "/") + "/audio/transcriptions"
}
isOfficialEndpoint := endpointURL == officialURL

if apiKey == "" && isOfficialEndpoint {
writeLog("Error: OpenAI API key is required but not provided")
return nil, fmt.Errorf("OpenAI API key is required but not provided")
}
Expand Down Expand Up @@ -188,7 +210,7 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
writeLog("Model: %s", model)
_ = writer.WriteField("model", model)

if strings.HasPrefix(model, "gpt-4o") {
if isOfficialEndpoint && strings.HasPrefix(model, "gpt-4o") {
if strings.Contains(model, "diarize") {
_ = writer.WriteField("response_format", "diarized_json")
} else {
Expand All @@ -197,7 +219,6 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
// gpt-4o models don't support timestamp_granularities with these formats
} else {
_ = writer.WriteField("response_format", "verbose_json")
// timestamp_granularities is only supported for whisper-1
if model == "whisper-1" {
_ = writer.WriteField("timestamp_granularities[]", "word") // Request word timestamps
_ = writer.WriteField("timestamp_granularities[]", "segment") // Request segment timestamps
Expand All @@ -224,8 +245,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
}

// Create request
writeLog("Sending request to OpenAI API...")
req, err := http.NewRequestWithContext(ctx, "POST", "https://api.openai.com/v1/audio/transcriptions", body)
writeLog("Sending request to %s...", endpointURL)
req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, body)
if err != nil {
writeLog("Error: Failed to create request: %v", err)
return nil, fmt.Errorf("failed to create request: %w", err)
Expand All @@ -235,9 +256,14 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn
req.Header.Set("Authorization", "Bearer "+apiKey)

// Execute request
client := &http.Client{
Timeout: 10 * time.Minute, // Generous timeout for large files
timeout := 10 * time.Minute
if !isOfficialEndpoint {
timeout = 30 * time.Minute // Default for self-hosted endpoints
}
if t := a.GetIntParameter(params, "timeout_minutes"); t > 0 {
timeout = time.Duration(t) * time.Minute
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
writeLog("Error: Request failed: %v", err)
Expand All @@ -247,8 +273,8 @@ func (a *OpenAIAdapter) Transcribe(ctx context.Context, input interfaces.AudioIn

if resp.StatusCode != http.StatusOK {
respBody, _ := io.ReadAll(resp.Body)
writeLog("Error: OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("OpenAI API error (status %d): %s", resp.StatusCode, string(respBody))
writeLog("Error: transcription API error (status %d): %s", resp.StatusCode, string(respBody))
return nil, fmt.Errorf("transcription API error (status %d): %s", resp.StatusCode, string(respBody))
}

writeLog("Response received. Parsing...")
Expand Down
43 changes: 43 additions & 0 deletions internal/transcription/adapters_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,49 @@ func BenchmarkModelRegistryLookup(b *testing.B) {
}
}

// TestOpenAIAdapter checks the identity, capabilities, and parameter schema
// reported by an OpenAI adapter constructed with an API key.
func TestOpenAIAdapter(t *testing.T) {
	adapter := adapters.NewOpenAIAdapter("sk-test")
	if adapter == nil {
		t.Fatal("NewOpenAIAdapter returned nil")
	}

	capabilities := adapter.GetCapabilities()
	if got := capabilities.ModelID; got != "openai_whisper" {
		t.Errorf("expected model ID 'openai_whisper', got %q", got)
	}
	if got := capabilities.ModelFamily; got != "openai" {
		t.Errorf("expected model family 'openai', got %q", got)
	}
	if enabled := capabilities.Features["diarization"]; !enabled {
		t.Error("diarization capability must be true")
	}

	// Walk the schema once: record whether base_url is exposed and reject a
	// fixed Options list on the model parameter.
	parameters := adapter.GetParameterSchema()
	foundBaseURL := false
	for _, param := range parameters {
		switch param.Name {
		case "base_url":
			foundBaseURL = true
		case "model":
			if len(param.Options) > 0 {
				t.Errorf("model parameter must not have a fixed Options list, got %v", param.Options)
			}
		}
	}
	if !foundBaseURL {
		t.Error("schema must include base_url parameter")
	}
}

// TestOpenAIAdapterWithBaseURL covers the keyless construction path used with
// custom endpoints: the adapter must still be created, still advertise
// diarization, and expose base_url as an optional string parameter so callers
// can point it at a self-hosted server.
func TestOpenAIAdapterWithBaseURL(t *testing.T) {
	a := adapters.NewOpenAIAdapter("")
	if a == nil {
		t.Fatal("NewOpenAIAdapter returned nil")
	}
	caps := a.GetCapabilities()
	if !caps.Features["diarization"] {
		t.Error("diarization capability must be true")
	}

	// The test name promises base URL coverage, so verify the schema entry
	// itself rather than only the capability flag.
	found := false
	for _, p := range a.GetParameterSchema() {
		if p.Name != "base_url" {
			continue
		}
		found = true
		if p.Required {
			t.Error("base_url parameter must be optional")
		}
		if p.Type != "string" {
			t.Errorf("base_url parameter must be of type string, got %v", p.Type)
		}
	}
	if !found {
		t.Error("schema must include base_url parameter")
	}
}

func BenchmarkParameterValidation(b *testing.B) {
reg := registry.GetRegistry()
adapter, err := reg.GetTranscriptionAdapter("whisperx")
Expand Down
6 changes: 6 additions & 0 deletions internal/transcription/unified_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,12 @@ func (u *UnifiedTranscriptionService) convertToOpenAIParams(params models.Whispe
if params.APIKey != nil && *params.APIKey != "" {
paramMap["api_key"] = *params.APIKey
}
if params.APIURL != nil && *params.APIURL != "" {
paramMap["base_url"] = *params.APIURL
}
if params.TimeoutMinutes != nil && *params.TimeoutMinutes > 0 {
paramMap["timeout_minutes"] = *params.TimeoutMinutes
}

return paramMap
}
Expand Down
22 changes: 20 additions & 2 deletions tests/adapter_registration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,19 @@ func TestAdapterEnvPathInjection(t *testing.T) {
}
}

// TestOpenAIAdapterConstruction tests the OpenAI adapter constructor with and
// without an API key. Both forms must yield a usable adapter, and the keyed
// adapter must advertise diarization.
func TestOpenAIAdapterConstruction(t *testing.T) {
	a := adapters.NewOpenAIAdapter("")
	if a == nil {
		t.Fatal("NewOpenAIAdapter returned nil with empty key")
	}

	a = adapters.NewOpenAIAdapter("sk-test")
	// Guard before dereferencing so a nil result is reported as a test
	// failure instead of a nil-pointer panic.
	if a == nil {
		t.Fatal("NewOpenAIAdapter returned nil with non-empty key")
	}
	if !a.GetCapabilities().Features["diarization"] {
		t.Error("diarization capability must be true")
	}
}

// TestRegisterAdapters tests that registerAdapters correctly registers all adapters
func TestRegisterAdapters(t *testing.T) {
// Clear registry before test
Expand All @@ -67,6 +80,8 @@ func TestRegisterAdapters(t *testing.T) {
adapters.NewParakeetAdapter(nvidiaEnvPath))
registry.RegisterTranscriptionAdapter("canary",
adapters.NewCanaryAdapter(nvidiaEnvPath))
registry.RegisterTranscriptionAdapter("openai_whisper",
adapters.NewOpenAIAdapter(""))

registry.RegisterDiarizationAdapter("pyannote",
adapters.NewPyAnnoteAdapter(nvidiaEnvPath))
Expand All @@ -75,8 +90,8 @@ func TestRegisterAdapters(t *testing.T) {

// Verify registrations
transcriptionAdapters := registry.GetTranscriptionAdapters()
if len(transcriptionAdapters) != 3 {
t.Errorf("Expected 3 transcription adapters, got %d", len(transcriptionAdapters))
if len(transcriptionAdapters) != 4 {
t.Errorf("Expected 4 transcription adapters, got %d", len(transcriptionAdapters))
}

// Check specific adapters are registered
Expand All @@ -89,6 +104,9 @@ func TestRegisterAdapters(t *testing.T) {
if _, exists := transcriptionAdapters["canary"]; !exists {
t.Error("canary adapter not registered")
}
if _, exists := transcriptionAdapters["openai_whisper"]; !exists {
t.Error("openai_whisper adapter not registered")
}

diarizationAdapters := registry.GetDiarizationAdapters()
if len(diarizationAdapters) != 2 {
Expand Down
Loading