|
@@ -0,0 +1,255 @@
|
|
|
+;;; guile-openai --- An OpenAI API client for Guile
|
|
|
+;;; Copyright © 2023 Andrew Whatson <whatson@tailcall.au>
|
|
|
+;;;
|
|
|
+;;; This file is part of guile-openai.
|
|
|
+;;;
|
|
|
+;;; guile-openai is free software: you can redistribute it and/or modify
|
|
|
+;;; it under the terms of the GNU Affero General Public License as
|
|
|
+;;; published by the Free Software Foundation, either version 3 of the
|
|
|
+;;; License, or (at your option) any later version.
|
|
|
+;;;
|
|
|
+;;; guile-openai is distributed in the hope that it will be useful, but
|
|
|
+;;; WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
+;;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
|
+;;; Affero General Public License for more details.
|
|
|
+;;;
|
|
|
+;;; You should have received a copy of the GNU Affero General Public
|
|
|
+;;; License along with guile-openai. If not, see
|
|
|
+;;; <https://www.gnu.org/licenses/>.
|
|
|
+
|
|
|
+(define-module (openai utils foreign)
|
|
|
+ #:use-module (ice-9 match)
|
|
|
+ #:use-module (ice-9 format)
|
|
|
+ #:use-module (ice-9 vlist)
|
|
|
+ #:use-module (srfi srfi-1)
|
|
|
+ #:use-module (srfi srfi-9)
|
|
|
+ #:use-module (srfi srfi-9 gnu)
|
|
|
+ #:use-module ((system foreign) #:prefix ffi:)
|
|
|
+ #:use-module ((system foreign) #:select (define-wrapped-pointer-type))
|
|
|
+ #:use-module (system foreign-library)
|
|
|
+ #:export (c-type?
|
|
|
+ c-type-name
|
|
|
+ c-type-size
|
|
|
+
|
|
|
+ int8 uint8 int16 uint16 int32 uint32 int64 uint64
|
|
|
+ float double complex-double complex-float
|
|
|
+ int unsigned-int long unsigned-long short unsigned-short
|
|
|
+ size_t ssize_t ptrdiff_t intptr_t uintptr_t
|
|
|
+ void pointer cstring bool
|
|
|
+
|
|
|
+ define-foreign-type
|
|
|
+ define-foreign-arg-type
|
|
|
+ define-foreign-return-type
|
|
|
+
|
|
|
+ define-foreign-enum-type
|
|
|
+ define-foreign-pointer-type
|
|
|
+
|
|
|
+ define-foreign-library
|
|
|
+ define-foreign-function
|
|
|
+ define-foreign-functions))
|
|
|
+
|
|
|
+;;; C type marshalling
|
|
|
+
|
|
|
+(define-record-type <c-type>
|
|
|
+ (%make-c-type name repr wrapper unwrapper)
|
|
|
+ c-type?
|
|
|
+ (name c-type-name)
|
|
|
+ (repr c-type-repr)
|
|
|
+ (wrapper c-type-wrapper)
|
|
|
+ (unwrapper c-type-unwrapper))
|
|
|
+
|
|
|
+(define* (print-c-type type #:optional port)
|
|
|
+ (format port "#<c-type ~a ~a>"
|
|
|
+ (c-type-name type)
|
|
|
+ (c-type-name (get-base-type (c-type-repr type)))))
|
|
|
+
|
|
|
+(define (c-type-size type)
|
|
|
+ (ffi:sizeof (c-type-repr type)))
|
|
|
+
|
|
|
+(set-record-type-printer! <c-type> print-c-type)
|
|
|
+
|
|
|
+(define-syntax-rule (define-foreign-type type-name base wrapper unwrapper)
|
|
|
+ (define type-name
|
|
|
+ (%make-c-type (symbol->string 'type-name)
|
|
|
+ (c-type-repr base)
|
|
|
+ wrapper unwrapper)))
|
|
|
+
|
|
|
+(define-syntax-rule (define-foreign-arg-type type-name base unwrapper)
|
|
|
+ (define-foreign-type type-name base #f unwrapper))
|
|
|
+
|
|
|
+(define-syntax-rule (define-foreign-return-type type-name base wrapper)
|
|
|
+ (define-foreign-type type-name base wrapper #f))
|
|
|
+
|
|
|
+;;; Base types
|
|
|
+
|
|
|
+(define %base-types vlist-null)
|
|
|
+
|
|
|
+(define (register-base-type! type)
|
|
|
+ (let ((repr (c-type-repr type)))
|
|
|
+ (unless (has-base-type? repr)
|
|
|
+ (set! %base-types (vhash-consv repr type %base-types)))))
|
|
|
+
|
|
|
+(define (has-base-type? repr)
|
|
|
+ (and (vhash-assv repr %base-types) #t))
|
|
|
+
|
|
|
+(define (get-base-type repr)
|
|
|
+ (match (vhash-assv repr %base-types)
|
|
|
+ ((_ . type) type)))
|
|
|
+
|
|
|
+(define-syntax-rule (define-base-type type-name repr)
|
|
|
+ (begin
|
|
|
+ (define type-name
|
|
|
+ (%make-c-type (symbol->string 'type-name) repr identity identity))
|
|
|
+ (register-base-type! type-name)))
|
|
|
+
|
|
|
+(define-base-type int8 ffi:int8)
|
|
|
+(define-base-type uint8 ffi:uint8)
|
|
|
+(define-base-type int16 ffi:int16)
|
|
|
+(define-base-type uint16 ffi:uint16)
|
|
|
+(define-base-type int32 ffi:int32)
|
|
|
+(define-base-type uint32 ffi:uint32)
|
|
|
+(define-base-type int64 ffi:int64)
|
|
|
+(define-base-type uint64 ffi:uint64)
|
|
|
+(define-base-type float ffi:float)
|
|
|
+(define-base-type double ffi:double)
|
|
|
+(define-base-type complex-double ffi:complex-double)
|
|
|
+(define-base-type complex-float ffi:complex-float)
|
|
|
+(define-base-type int ffi:int)
|
|
|
+(define-base-type unsigned-int ffi:unsigned-int)
|
|
|
+(define-base-type long ffi:long)
|
|
|
+(define-base-type unsigned-long ffi:unsigned-long)
|
|
|
+(define-base-type short ffi:short)
|
|
|
+(define-base-type unsigned-short ffi:unsigned-short)
|
|
|
+(define-base-type size_t ffi:size_t)
|
|
|
+(define-base-type ssize_t ffi:ssize_t)
|
|
|
+(define-base-type ptrdiff_t ffi:ptrdiff_t)
|
|
|
+(define-base-type intptr_t ffi:intptr_t)
|
|
|
+(define-base-type uintptr_t ffi:uintptr_t)
|
|
|
+(define-base-type void ffi:void)
|
|
|
+(define-base-type pointer '*)
|
|
|
+
|
|
|
+;;; Common types
|
|
|
+
|
|
|
+(define-foreign-type cstring pointer
|
|
|
+ ffi:pointer->string
|
|
|
+ ffi:string->pointer)
|
|
|
+
|
|
|
+(define-foreign-type bool int
|
|
|
+ (lambda (int) (not (zero? int)))
|
|
|
+ (lambda (bool) (if bool 1 0)))
|
|
|
+
|
|
|
+;;; Enum types
|
|
|
+
|
|
|
+(define-syntax-rule (define-foreign-enum-type enum-name enum-base
|
|
|
+ enumerator? enumerator-list
|
|
|
+ int->enumerator enumerator->int
|
|
|
+ (enumerator ...))
|
|
|
+ (begin
|
|
|
+ (define (enumerator? sym)
|
|
|
+ (and (enumerator->int sym) #t))
|
|
|
+ (define (enumerator-list)
|
|
|
+ (%dfe-enum-symbols (enumerator ...)))
|
|
|
+ (define enumerator->int
|
|
|
+ (let ((lookup (alist->vhash (map cons
|
|
|
+ (%dfe-enum-symbols (enumerator ...))
|
|
|
+ (%dfe-enum-values (enumerator ...)))
|
|
|
+ hashq)))
|
|
|
+ (lambda (sym)
|
|
|
+ (and=> (vhash-assq sym lookup) cdr))))
|
|
|
+ (define int->enumerator
|
|
|
+ (let ((lookup (alist->vhash (map cons
|
|
|
+ (%dfe-enum-values (enumerator ...))
|
|
|
+ (%dfe-enum-symbols (enumerator ...)))
|
|
|
+ hashv)))
|
|
|
+ (lambda (int)
|
|
|
+ (and=> (vhash-assv int lookup) cdr))))
|
|
|
+ (define-foreign-type enum-name enum-base
|
|
|
+ int->enumerator enumerator->int)))
|
|
|
+
|
|
|
+(define-syntax %dfe-enum-symbols
|
|
|
+ (syntax-rules (=>)
|
|
|
+ ((_ (args ...))
|
|
|
+ (%dfe-enum-symbols (args ...) ()))
|
|
|
+ ((_ (symbol => value args ...) (syms ...))
|
|
|
+ (%dfe-enum-symbols (args ...) (syms ... symbol)))
|
|
|
+ ((_ (symbol args ...) (syms ...))
|
|
|
+ (%dfe-enum-symbols (args ...) (syms ... symbol)))
|
|
|
+ ((_ () (syms ...))
|
|
|
+ '(syms ...))))
|
|
|
+
|
|
|
+(define-syntax %dfe-enum-values
|
|
|
+ (syntax-rules (=>)
|
|
|
+ ((_ (args ...))
|
|
|
+ (%dfe-enum-values (args ...) () -1))
|
|
|
+ ((_ (symbol => value args ...) (vals ...) previous)
|
|
|
+ (%dfe-enum-values (args ...) (vals ... value) value))
|
|
|
+ ((_ (symbol args ...) (vals ...) previous)
|
|
|
+ (%dfe-enum-values (args ...) (vals ... (1+ previous)) (1+ previous)))
|
|
|
+ ((_ () (vals ...) previous)
|
|
|
+ (list vals ...))))
|
|
|
+
|
|
|
+;;; Pointer types
|
|
|
+
|
|
|
+(define-syntax-rule (define-foreign-pointer-type pointer-name record-type
|
|
|
+ record? pointer->record record->pointer)
|
|
|
+ (begin
|
|
|
+ (define-wrapped-pointer-type record-type
|
|
|
+ record? pointer->record record->pointer
|
|
|
+ (lambda (rec port)
|
|
|
+ (let ((address (ffi:pointer-address (record->pointer rec))))
|
|
|
+ (format port "#<~a 0x~x>" 'pointer-name address))))
|
|
|
+ (define-foreign-type pointer-name pointer
|
|
|
+ pointer->record record->pointer)))
|
|
|
+
|
|
|
+;;; Function wrappers
|
|
|
+
|
|
|
+(define-syntax-rule (define-foreign-library library path args ...)
|
|
|
+ (define library
|
|
|
+ (load-foreign-library path args ...)))
|
|
|
+
|
|
|
+(define-syntax-rule (define-foreign-function library
|
|
|
+ (function-name signature ...))
|
|
|
+ (define function-name
|
|
|
+ (apply wrapped-foreign-library-function library
|
|
|
+ (symbol->string 'function-name)
|
|
|
+ (%dff-parse-signature (signature ...)))))
|
|
|
+
|
|
|
+(define-syntax %dff-parse-signature
|
|
|
+ (syntax-rules (->)
|
|
|
+ ((_ (-> return-type) arg-types ...)
|
|
|
+ (list #:return-type return-type
|
|
|
+ #:arg-types (list arg-types ...)))
|
|
|
+ ((_ (next rest ...) arg-types ...)
|
|
|
+ (%dff-parse-signature (rest ...) arg-types ... next))))
|
|
|
+
|
|
|
+(define-syntax-rule (define-foreign-functions library
|
|
|
+ (function-name signature ...) ...)
|
|
|
+ (begin
|
|
|
+ (define-foreign-function library
|
|
|
+ (function-name signature ...))
|
|
|
+ ...))
|
|
|
+
|
|
|
+(define (procedure-takes-rest? proc)
|
|
|
+ (caddr (procedure-minimum-arity proc)))
|
|
|
+
|
|
|
+(define* (wrapped-foreign-library-function library function-name
|
|
|
+ #:key return-type arg-types)
|
|
|
+ (let* ((wrapper (c-type-wrapper return-type))
|
|
|
+ (wrap-result (if (procedure-takes-rest? wrapper)
|
|
|
+ wrapper
|
|
|
+ (lambda (result . args)
|
|
|
+ (wrapper result))))
|
|
|
+ (unwrappers (map c-type-unwrapper arg-types))
|
|
|
+ (unwrap-args (lambda (args)
|
|
|
+ (map (lambda (unwrap arg)
|
|
|
+ (unwrap arg))
|
|
|
+ unwrappers args)))
|
|
|
+ (foreign-function
|
|
|
+ (foreign-library-function library function-name
|
|
|
+ #:return-type (c-type-repr return-type)
|
|
|
+ #:arg-types (map c-type-repr arg-types))))
|
|
|
+ (lambda args
|
|
|
+ (let* ((raw-args (unwrap-args args))
|
|
|
+ (raw-result (apply foreign-function raw-args))
|
|
|
+ (result (apply wrap-result raw-result args)))
|
|
|
+ result))))
|