PyGPT-J API Reference

pygptj.model

This module provides a simple Python API around gpt-j.

Model

Model(
    model_path,
    prompt_context="",
    prompt_prefix="",
    prompt_suffix="",
    log_level=logging.ERROR,
)

GPT-J model

Example usage

from pygptj.model import Model

model = Model(model_path='path/to/ggml/model')
for token in model.generate("Tell me a joke"):
    print(token, end='', flush=True)

Parameters:

Name            Type  Description                             Default
model_path      str   The path to a gpt-j ggml model          required
prompt_context  str   The global context of the interaction   ''
prompt_prefix   str   The prompt prefix                       ''
prompt_suffix   str   The prompt suffix                       ''
log_level       int   Logging level                           logging.ERROR
Source code in pygptj/model.py
def __init__(self,
             model_path: str,
             prompt_context: str = '',
             prompt_prefix: str = '',
             prompt_suffix: str = '',
             log_level: int = logging.ERROR):
    """
    :param model_path: The path to a gpt-j `ggml` model
    :param prompt_context: the global context of the interaction
    :param prompt_prefix: the prompt prefix
    :param prompt_suffix: the prompt suffix
    :param log_level: logging level
    """
    # set logging level
    set_log_level(log_level)
    self._ctx = None

    if not Path(model_path).is_file():
        raise Exception(f"File {model_path} not found!")

    self.model_path = model_path

    self._model = pp.gptj_model()
    self._vocab = pp.gpt_vocab()

    # load model
    self._load_model()

    # gpt params
    self.gpt_params = pp.gptj_gpt_params()
    self.hparams = pp.gptj_hparams()

    self.res = ""

    self.logits = []

    self._n_past = 0
    self.prompt_context = prompt_context
    self.prompt_prefix = prompt_prefix
    self.prompt_suffix = prompt_suffix

    self._prompt_context_tokens = []
    self._prompt_prefix_tokens = []
    self._prompt_suffix_tokens = []

    self.reset()
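
The prompt context, prefix, and suffix can be used to set up a chat-style interaction. A minimal sketch (the context strings are illustrative, not part of the API):

model = Model(model_path='path/to/ggml/model',
              prompt_context="Act as Bob. Bob is a helpful assistant.",
              prompt_prefix="\nUser: ",
              prompt_suffix="\nBob: ")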

generate

generate(
    prompt,
    n_predict=None,
    antiprompt=None,
    seed=None,
    n_threads=4,
    top_k=40,
    top_p=0.9,
    temp=0.9,
)

Runs GPT-J inference and yields new predicted tokens
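
Example usage (a minimal sketch; the model path and the antiprompt value are placeholders):

model = Model(model_path='path/to/ggml/model')
for token in model.generate("Once upon a time,", n_predict=128, antiprompt="User:"):
    print(token, end='', flush=True)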

Parameters:

Name        Type              Description                                                          Default
prompt      str               The prompt                                                           required
n_predict   Union[None, int]  If not None, inference stops once n_predict tokens have been         None
                              generated; otherwise it continues until the end-of-text token
antiprompt  str               The stop word: generation stops if this word is predicted; keep it   None
                              None to handle stopping yourself
seed        int               Random seed                                                          None
n_threads   int               The number of CPU threads                                            4
top_k       int               Top-k sampling parameter                                             40
top_p       float             Top-p sampling parameter                                             0.9
temp        float             Temperature                                                          0.9

Returns:

Type       Description
Generator  Tokens generator

Source code in pygptj/model.py
def generate(self,
             prompt: str,
             n_predict: Union[None, int] = None,
             antiprompt: str = None,
             seed: int = None,
             n_threads: int = 4,
             top_k: int = 40,
             top_p: float = 0.9,
             temp: float = 0.9,
             ) -> Generator:
    """
    Runs GPT-J inference and yields new predicted tokens

    :param prompt: The prompt :)
    :param n_predict: if n_predict is not None, the inference will stop if it reaches `n_predict` tokens, otherwise
                      it will continue until `end of text` token
    :param antiprompt: aka the stop word, the generation will stop if this word is predicted,
                       keep it None to handle it in your own way
    :param seed: random seed
    :param n_threads: The number of CPU threads
    :param top_k: top K sampling parameter
    :param top_p: top P sampling parameter
    :param temp: temperature

    :return: Tokens generator
    """
    if seed is None or seed < 0:
        seed = int(time.time())

    logging.info(f'seed = {seed}')

    if self._n_past == 0 or antiprompt is None:
        # add the prefix to the context
        embd_inp = self._prompt_prefix_tokens + pp.gpt_tokenize(self._vocab, prompt) + self._prompt_suffix_tokens
    else:
        # do not add the prefix again as it is already in the previous generated context
        embd_inp = pp.gpt_tokenize(self._vocab, prompt) + self._prompt_suffix_tokens

    if n_predict is not None:
        n_predict = min(n_predict, self.hparams.n_ctx - len(embd_inp))
    logging.info(f'Number of tokens in prompt = {len(embd_inp)}')

    embd = []
    # add global context for the first time
    if self._n_past == 0:
        for tok in self._prompt_context_tokens:
            embd.append(tok)

    # consume input tokens
    for tok in embd_inp:
        embd.append(tok)

    # determine the required inference memory per token:
    mem_per_token = 0
    logits, mem_per_token = pp.gptj_eval(self._model, n_threads, 0, [0, 1, 2, 3], mem_per_token)

    i = len(embd) - 1
    id = 0
    if antiprompt is not None:
        sequence_queue = []
        stop_word = antiprompt.strip()

    while id != 50256:  # end of text token
        if n_predict is not None:  # break the generation if n_predict
            if i >= (len(embd_inp) + n_predict):
                break
        i += 1
        # predict
        if len(embd) > 0:
            try:
                logits, mem_per_token = pp.gptj_eval(self._model, n_threads, self._n_past, embd, mem_per_token)
                self.logits.append(logits)
            except Exception as e:
                print(f"Failed to predict\n {e}")
                return

        self._n_past += len(embd)
        embd.clear()

        if i >= len(embd_inp):
            # sample next token
            n_vocab = self.hparams.n_vocab
            t_start_sample_us = int(round(time.time() * 1000000))
            id = pp.gpt_sample_top_k_top_p(self._vocab, logits[-n_vocab:], top_k, top_p, temp, seed)
            if id == 50256:  # end of text token
                break
            # add the token to the context
            embd.append(id)
            token = self._vocab.id_to_token[id]
            # antiprompt
            if antiprompt is not None:
                if token == '\n':
                    sequence_queue.append(token)
                    continue
                if len(sequence_queue) != 0:
                    if stop_word.startswith(''.join(sequence_queue).strip()):
                        sequence_queue.append(token)
                        if ''.join(sequence_queue).strip() == stop_word:
                            break
                        else:
                            continue
                    else:
                        # consume sequence queue tokens
                        while len(sequence_queue) != 0:
                            yield sequence_queue.pop(0)
                        sequence_queue = []

            yield token

cpp_generate

cpp_generate(
    prompt,
    new_text_callback=None,
    logits_callback=None,
    n_predict=128,
    seed=-1,
    n_threads=4,
    top_k=40,
    top_p=0.9,
    temp=0.9,
    n_batch=8,
)

Runs inference through the C++ generate function
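
Example usage (a minimal sketch; the callback just streams text to stdout, and the model path is a placeholder):

model = Model(model_path='path/to/ggml/model')

def new_text(text):
    print(text, end='', flush=True)

generated = model.cpp_generate("Tell me a joke", new_text_callback=new_text, n_predict=64)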

Parameters:

Name               Type                          Description                                    Default
prompt             str                           The prompt                                     required
new_text_callback  Callable[[str], None]         A callback function called whenever new text   None
                                                 is generated
logits_callback    Callable[[np.ndarray], None]  A callback function to access the logits on    None
                                                 every inference step
n_predict          int                           Number of tokens to generate                   128
seed               int                           The random seed                                -1
n_threads          int                           Number of threads                              4
top_k              int                           Top-k sampling parameter                       40
top_p              float                         Top-p sampling parameter                       0.9
temp               float                         Temperature sampling parameter                 0.9
n_batch            int                           Batch size for prompt processing               8

Returns:

Type  Description
str   The newly generated text

Source code in pygptj/model.py
def cpp_generate(self,
                 prompt: str,
                 new_text_callback: Callable[[str], None] = None,
                 logits_callback: Callable[[np.ndarray], None] = None,
                 n_predict: int = 128,
                 seed: int = -1,
                 n_threads: int = 4,
                 top_k: int = 40,
                 top_p: float = 0.9,
                 temp: float = 0.9,
                 n_batch: int = 8,
                 ) -> str:
    """
    Runs inference through the C++ generate function

    :param prompt: the prompt
    :param new_text_callback: a callback function called when new text is generated, default `None`
    :param logits_callback: a callback function to access the logits on every inference
    :param n_predict: number of tokens to generate
    :param seed: The random seed
    :param n_threads: Number of threads
    :param top_k: top_k sampling parameter
    :param top_p: top_p sampling parameter
    :param temp: temperature sampling parameter
    :param n_batch: batch size for prompt processing

    :return: the new generated text
    """
    self.gpt_params.prompt = prompt
    self.gpt_params.n_predict = n_predict
    self.gpt_params.seed = seed
    self.gpt_params.n_threads = n_threads
    self.gpt_params.top_k = top_k
    self.gpt_params.top_p = top_p
    self.gpt_params.temp = temp
    self.gpt_params.n_batch = n_batch

    # assign new_text_callback
    self.res = ""
    Model._new_text_callback = new_text_callback

    # assign _logits_callback used for saving logits, token by token
    Model._logits_callback = logits_callback

    # run the prediction
    pp.gptj_generate(self.gpt_params, self._model, self._vocab, self._call_new_text_callback,
                     self._call_logits_callback)
    return self.res

braindump

braindump(path)

Dumps the logits collected so far to a .npy file
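
For example, to save the logits after a generation and load them back with NumPy (the output path is a placeholder):

import numpy as np

model.braindump('logits.npy')
logits = np.load('logits.npy')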

Parameters:

Name  Type  Description  Default
path  str   Output path  required

Returns:

Type  Description
None  None

Source code in pygptj/model.py
300
301
302
303
304
305
306
def braindump(self, path: str) -> None:
    """
    Dumps the logits collected so far to a .npy file
    :param path: Output path
    :return: None
    """
    np.save(path, np.asarray(self.logits))

reset

reset()

Resets the context
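
For example, to clear the accumulated context between two unrelated prompts (a minimal sketch):

for token in model.generate("First prompt"):
    print(token, end='', flush=True)

model.reset()  # forget the previous interaction

for token in model.generate("A completely unrelated prompt"):
    print(token, end='', flush=True)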

Returns:

Type  Description
None  None

Source code in pygptj/model.py
def reset(self) -> None:
    """
    Resets the context
    :return: None
    """
    self._n_past = 0
    self._prompt_context_tokens = pp.gpt_tokenize(self._vocab, self.prompt_context)
    self._prompt_prefix_tokens = pp.gpt_tokenize(self._vocab, self.prompt_prefix)
    self._prompt_suffix_tokens = pp.gpt_tokenize(self._vocab, self.prompt_suffix)

get_params staticmethod

get_params(params)

Returns a dict representation of the params
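
For example, to inspect the current generation parameters as a plain dict:

params = Model.get_params(model.gpt_params)
print(params)  # e.g. {'n_batch': 8, 'n_predict': 128, 'seed': -1, ...}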

Returns:

Type  Description
dict  A dict of the params

Source code in pygptj/model.py
@staticmethod
def get_params(params) -> dict:
    """
    Returns a `dict` representation of the params
    :param params: a params object (e.g. `gpt_params` or `hparams`)
    :return: params dict
    """
    res = {}
    for param in dir(params):
        if param.startswith('__'):
            continue
        res[param] = getattr(params, param)
    return res