Skip to content

vllm.model_executor.models.gpt_oss

GptOssForCausalLM

Bases: Module, SupportsPP

Source code in vllm/model_executor/models/gpt_oss.py
class GptOssForCausalLM(nn.Module, SupportsPP):
    """GPT-OSS causal language model: ``GptOssModel`` plus an LM head.

    Checkpoint tensor names are rewritten with ``hf_to_vllm_mapper`` before
    loading, and the q/k/v projections are packed into a single ``qkv``
    parameter (``packed_modules_mapping``).
    """

    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    # Hugging Face -> vLLM parameter-name rewrites.
    # NOTE(review): the specific MXFP4 suffixes ("*_blocks", "*_scales") are
    # listed ahead of the plain ".gate_up_proj"/".down_proj" entries; keep
    # that order in case WeightsMapper applies suffix rules sequentially.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={
            ".self_attn.": ".attn.",
        },
        orig_to_new_suffix={
            ".embed_tokens.weight": ".embedding.weight",

            # MoE MXFP4 weights
            ".gate_up_proj_blocks": ".w13_weight",
            ".down_proj_blocks": ".w2_weight",
            ".gate_up_proj_scales": ".w13_weight_scale",
            ".down_proj_scales": ".w2_weight_scale",

            # MoE other weights
            ".gate_up_proj": ".w13_weight",
            ".down_proj": ".w2_weight",

            # MoE Bias
            ".gate_up_proj_bias": ".w13_bias",
            ".down_proj_bias": ".w2_bias",
        },
    )

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the decoder stack, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the HF model config is read
                from ``vllm_config.model_config.hf_config``.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = vllm_config.model_config.hf_config

        self.model = GptOssModel(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        if self.config.tie_word_embeddings:
            # load_weights() skips "lm_head." tensors in this case, so the
            # head must share the input embedding's weight; otherwise it would
            # never be initialised from the checkpoint.
            self.lm_head = self.lm_head.tie_weights(self.model.embedding)
        self.logits_processor = LogitsProcessor(self.config.vocab_size)
        # Exposed for pipeline parallelism (SupportsPP).
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings for ``input_ids`` from the backbone."""
        return self.model.get_input_embeddings(input_ids)

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the decoder stack; see ``GptOssModel.forward``."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project ``hidden_states`` through ``lm_head`` into logits."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights``; return the loaded parameter names.

        ``lm_head.`` tensors are skipped when embeddings are tied, since the
        head shares the embedding weight (see ``__init__``).
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

config instance-attribute

config = hf_config

hf_to_vllm_mapper class-attribute instance-attribute

hf_to_vllm_mapper = WeightsMapper(
    orig_to_new_substr={".self_attn.": ".attn."},
    orig_to_new_suffix={
        ".embed_tokens.weight": ".embedding.weight",
        ".gate_up_proj_blocks": ".w13_weight",
        ".down_proj_blocks": ".w2_weight",
        ".gate_up_proj_scales": ".w13_weight_scale",
        ".down_proj_scales": ".w2_weight_scale",
        ".gate_up_proj": ".w13_weight",
        ".down_proj": ".w2_weight",
        ".gate_up_proj_bias": ".w13_bias",
        ".down_proj_bias": ".w2_bias",
    },
)

lm_head instance-attribute

lm_head = ParallelLMHead(vocab_size, hidden_size)

logits_processor instance-attribute

logits_processor = LogitsProcessor(vocab_size)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors
)

model instance-attribute

model = GptOssModel(
    vllm_config=vllm_config,
    prefix=maybe_prefix(prefix, "model"),
)

packed_modules_mapping class-attribute instance-attribute

packed_modules_mapping = {
    "qkv": ["q_proj", "k_proj", "v_proj"]
}

vllm_config instance-attribute

vllm_config = vllm_config

__init__

__init__(vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gpt_oss.py
def __init__(
    self,
    vllm_config: VllmConfig,
    prefix: str = "",
):
    """Build the decoder stack, LM head and logits processor.

    Args:
        vllm_config: Engine configuration; the HF model config is read from
            ``vllm_config.model_config.hf_config``.
        prefix: Parameter-name prefix for this module.
    """
    super().__init__()
    self.vllm_config = vllm_config
    self.config = vllm_config.model_config.hf_config

    self.model = GptOssModel(
        vllm_config=vllm_config,
        prefix=maybe_prefix(prefix, "model"),
    )
    self.lm_head = ParallelLMHead(
        self.config.vocab_size,
        self.config.hidden_size,
    )
    if self.config.tie_word_embeddings:
        # load_weights() skips "lm_head." tensors in this case, so the head
        # must share the input embedding's weight; otherwise it would never
        # be initialised from the checkpoint.
        self.lm_head = self.lm_head.tie_weights(self.model.embedding)
    self.logits_processor = LogitsProcessor(self.config.vocab_size)
    # Exposed for pipeline parallelism (SupportsPP).
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)

compute_logits

compute_logits(
    hidden_states: Tensor,
    sampling_metadata: SamplingMetadata,
) -> Tensor
Source code in vllm/model_executor/models/gpt_oss.py
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Turn final hidden states into vocabulary logits.

    The logits processor is handed the LM head module together with the
    hidden states and sampling metadata, and its result is returned as-is.
    """
    head = self.lm_head
    processor = self.logits_processor
    return processor(head, hidden_states, sampling_metadata)

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/models/gpt_oss.py
def forward(self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            intermediate_tensors: Optional[IntermediateTensors] = None,
            inputs_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Delegate the forward pass to the wrapped ``GptOssModel``.

    All four arguments are forwarded positionally, in the same order.
    """
    backbone = self.model
    hidden = backbone(input_ids, positions, intermediate_tensors,
                      inputs_embeds)
    return hidden

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gpt_oss.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` by delegating to the wrapped backbone model."""
    backbone = self.model
    embedded = backbone.get_input_embeddings(input_ids)
    return embedded

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/gpt_oss.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors through ``AutoWeightsLoader``.

    Names are rewritten with ``hf_to_vllm_mapper``; ``lm_head.`` tensors are
    skipped when the config ties word embeddings. Returns the set of loaded
    parameter names reported by the loader.
    """
    if self.config.tie_word_embeddings:
        skipped = ["lm_head."]
    else:
        skipped = None
    auto_loader = AutoWeightsLoader(self, skip_prefixes=skipped)
    return auto_loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

GptOssModel

Bases: Module

Source code in vllm/model_executor/models/gpt_oss.py
@support_torch_compile
class GptOssModel(nn.Module):
    """GPT-OSS decoder stack: token embedding, transformer blocks, final norm.

    Under pipeline parallelism only layers ``[start_layer, end_layer)`` (as
    returned by ``make_layers``) run on this rank. ``load_weights`` dispatches
    to an MXFP4 loader or a generic loader depending on the checkpoint's
    ``quantization_config``.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build embedding, transformer blocks and final RMSNorm.

        Args:
            vllm_config: Engine configuration (model, cache, quant and
                parallel configs are read from it).
            prefix: Parameter-name prefix; layers live under
                ``f"{prefix}.layers"``.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.cache_config = vllm_config.cache_config
        self.quant_config = vllm_config.quant_config
        self.parallel_config = vllm_config.parallel_config
        self.embedding = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.config.num_hidden_layers,
            lambda prefix: TransformerBlock(
                self.config,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
        # Hidden states and the residual stream are what forward() hands to
        # the next pipeline stage.
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], self.config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embedding(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run this rank's slice of the decoder.

        On the first PP rank the input is ``inputs_embeds`` if given, else the
        embedding of ``input_ids``; later ranks resume from
        ``intermediate_tensors``. Non-last ranks return an
        ``IntermediateTensors`` holding ``hidden_states`` and ``residual``;
        the last rank returns the final RMS-normalised hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                x = inputs_embeds
            else:
                x = self.get_input_embeddings(input_ids)

            residual = None
        else:
            assert intermediate_tensors is not None
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            x, residual = layer(x, positions, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": x,
                "residual": residual
            })
        x, _ = self.norm(x, residual)
        return x

    def _load_weights_mxfp4(
        self,
        ep_rank_end: int,
        ep_rank_start: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an MXFP4 checkpoint, slicing MoE tensors for this rank.

        With expert parallelism each rank keeps experts
        ``[ep_rank_start, ep_rank_end)``; otherwise the intermediate
        dimension is split across TP ranks in whole MXFP4 blocks. Attention
        sinks are narrowed to this rank's heads and q/k/v shards are routed
        through ``stacked_params_mapping``.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        mxfp4_block = 32
        use_ep = self.parallel_config.enable_expert_parallel
        num_experts = self.config.num_local_experts

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Split the intermediate dim in whole MXFP4 blocks so each rank's
        # slice of the scales lines up with its slice of the packed values.
        intermediate_size = self.config.intermediate_size
        intermediate_size_block = intermediate_size // mxfp4_block
        per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                                tp_size)
        per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                      mxfp4_block)

        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        def load_moe_param(name: str, tensor: torch.Tensor) -> None:
            # The per-rank slice is already taken by the caller, so the
            # parameter's loader gets no shard/expert id.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param,
                          tensor,
                          weight_name=name,
                          shard_id=None,
                          expert_id=None)
            loaded_params.add(name)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            # FIXME(woosuk): Remove this after testing.
            weight = weight.cuda()

            # The "*_weight_scale" checks must precede the "*_weight" ones,
            # which are substrings of them.
            if ".w13_weight_scale" in name:
                # Handle MLP gate and up projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # gate and up are packed along dim 1, hence 2x bounds.
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]
                load_moe_param(name, narrow_weight)
                continue
            elif ".w2_weight_scale" in name:
                # Handle MLP down projection weights scale
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    # One scale per MXFP4 block along the last dim.
                    narrow_weight = weight[..., tp_rank_start //
                                           mxfp4_block:tp_rank_end //
                                           mxfp4_block]
                load_moe_param(name, narrow_weight)
                continue
            elif ".w13_weight" in name:
                # Handle MLP gate and up projection weights
                # flat weight from (E, 2 * N, block_size, entry_per_block)
                # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
                weight = weight.view(num_experts, 2 * intermediate_size,
                                     -1).contiguous()

                # Extract gate and up projection parts
                # since the weight is shuffled, we can slice directly
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end,
                                           ...]
                load_moe_param(name, narrow_weight)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                # same flatten here, but since 2 mx4 value are packed in 1
                # uint8, divide by 2
                weight = weight.view(num_experts, -1,
                                     intermediate_size // 2).contiguous()
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[...,
                                           tp_rank_start // 2:tp_rank_end // 2]
                load_moe_param(name, narrow_weight)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]
                load_moe_param(name, narrow_weight)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only rank 0 loads the bias to avoid duplicating it.
                    # Build a fresh zero tensor instead of zeroing the
                    # checkpoint tensor in place.
                    weight = torch.zeros_like(weight)
                load_moe_param(name, weight)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

    def _load_weights_other(
        self,
        ep_rank_start: int,
        ep_rank_end: int,
        heads_per_rank: int,
        head_start: int,
        weights: Iterable[tuple[str, torch.Tensor]],
        stacked_params_mapping: list[tuple[str, ...]],
    ) -> set[str]:
        """Load an unquantized checkpoint, slicing MoE tensors for this rank.

        MoE weights are sliced (by expert range under EP, otherwise by the TP
        share of the intermediate dimension), have their last two dims
        swapped, and are copied directly into the parameters.

        Returns:
            Names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        use_ep = self.parallel_config.enable_expert_parallel

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        intermediate_size = self.config.intermediate_size
        per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
        # Calculate common slicing bounds for current rank
        tp_rank_start = tp_rank * per_rank_intermediate_size
        tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                          intermediate_size)

        for name, weight in weights:
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue

            if ".w13_weight" in name:
                # Gate/up are packed along the last dim, hence 2x bounds.
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, :,
                                           2 * tp_rank_start:2 * tp_rank_end]

                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]

                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_weight" in name:
                # Handle MLP down projection weights
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
                narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
                param = params_dict[name]

                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w13_bias" in name:
                # Handle MLP gate and up projection biases
                if use_ep:
                    narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
                else:
                    narrow_weight = weight[:,
                                           2 * tp_rank_start:2 * tp_rank_end]

                param = params_dict[name]
                param.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            elif ".w2_bias" in name:
                # Handle MLP down projection bias
                if use_ep:
                    weight = weight[ep_rank_start:ep_rank_end, ...]
                elif tp_rank != 0:
                    # Only rank 0 loads the bias to avoid duplicating it.
                    # Build a fresh zero tensor instead of zeroing the
                    # (possibly memory-mapped) checkpoint tensor in place.
                    weight = torch.zeros_like(weight)
                param = params_dict[name]
                param.copy_(weight)
                loaded_params.add(name)
                continue
            elif "sinks" in name:
                # Handle attention sinks (distributed across ranks)
                param = params_dict[name]
                narrow_weight = weight.narrow(0, head_start, heads_per_rank)
                param.data.copy_(narrow_weight)
                loaded_params.add(name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if weight_loader == default_weight_loader:
                    weight_loader(param, weight)
                else:
                    weight_loader(param, weight, shard_id)
                break
            else:
                # Handle all other weights with potential renaming
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, weight)
            loaded_params.add(name)
        return loaded_params

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``weights``; return the loaded parameter names.

        Computes this rank's attention-head and expert ranges, then
        dispatches to the MXFP4 loader when the HF config declares
        ``quant_method == "mxfp4"`` and to the generic loader otherwise.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv", ".q_proj", "q"),
            (".qkv", ".k_proj", "k"),
            (".qkv", ".v_proj", "v"),
        ]

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()

        # Attention heads per rank
        heads_per_rank = self.config.num_attention_heads // tp_size
        head_start = tp_rank * heads_per_rank

        ep_size = get_ep_group().world_size
        ep_rank = get_ep_group().rank
        num_experts = self.config.num_local_experts
        experts_per_rank = num_experts // ep_size
        ep_rank_start = ep_rank * experts_per_rank
        ep_rank_end = (ep_rank + 1) * experts_per_rank

        quant_cfg = getattr(self.config, "quantization_config", None)
        quant_method = (quant_cfg["quant_method"]
                        if quant_cfg is not None else None)
        if quant_method == "mxfp4":
            loader = self._load_weights_mxfp4
        else:
            loader = self._load_weights_other
        # Pass everything by keyword: the two loaders declare ep_rank_start
        # and ep_rank_end in opposite orders, so a positional call would
        # swap them for one loader and produce empty expert slices under EP.
        return loader(
            ep_rank_start=ep_rank_start,
            ep_rank_end=ep_rank_end,
            heads_per_rank=heads_per_rank,
            head_start=head_start,
            weights=weights,
            stacked_params_mapping=stacked_params_mapping,
        )

cache_config instance-attribute

cache_config = cache_config

config instance-attribute

config = hf_config

embedding instance-attribute

embedding = VocabParallelEmbedding(vocab_size, hidden_size)

make_empty_intermediate_tensors instance-attribute

make_empty_intermediate_tensors = (
    make_empty_intermediate_tensors_factory(
        ["hidden_states", "residual"], hidden_size
    )
)

norm instance-attribute

norm = RMSNorm(hidden_size, eps=1e-05)

parallel_config instance-attribute

parallel_config = parallel_config

quant_config instance-attribute

quant_config = quant_config

__init__

__init__(*, vllm_config: VllmConfig, prefix: str = '')
Source code in vllm/model_executor/models/gpt_oss.py
def __init__(
    self,
    *,
    vllm_config: VllmConfig,
    prefix: str = "",
):
    """Build embedding, transformer blocks and final RMSNorm.

    Args:
        vllm_config: Engine configuration (model, cache, quant and parallel
            configs are read from it).
        prefix: Parameter-name prefix; layers live under
            ``f"{prefix}.layers"``.
    """
    super().__init__()
    self.config = vllm_config.model_config.hf_config
    self.cache_config = vllm_config.cache_config
    self.quant_config = vllm_config.quant_config
    self.parallel_config = vllm_config.parallel_config
    self.embedding = VocabParallelEmbedding(
        self.config.vocab_size,
        self.config.hidden_size,
    )
    self.start_layer, self.end_layer, self.layers = make_layers(
        self.config.num_hidden_layers,
        lambda prefix: TransformerBlock(
            self.config,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
            prefix=prefix,
        ),
        prefix=f"{prefix}.layers",
    )
    self.norm = RMSNorm(self.config.hidden_size, eps=1e-5)
    # Hidden states and the residual stream are what forward() hands to the
    # next pipeline stage.
    self.make_empty_intermediate_tensors = (
        make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], self.config.hidden_size))

_load_weights_mxfp4

_load_weights_mxfp4(
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]
Source code in vllm/model_executor/models/gpt_oss.py
def _load_weights_mxfp4(
    self,
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Load an MXFP4 checkpoint, slicing MoE tensors for this rank.

    With expert parallelism each rank keeps experts
    ``[ep_rank_start, ep_rank_end)``; otherwise the intermediate dimension is
    split across TP ranks in whole MXFP4 blocks. Attention sinks are narrowed
    to this rank's heads and q/k/v shards are routed through
    ``stacked_params_mapping``.

    Returns:
        Names of the parameters that were loaded.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    mxfp4_block = 32
    use_ep = self.parallel_config.enable_expert_parallel
    num_experts = self.config.num_local_experts

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    # Split the intermediate dim in whole MXFP4 blocks so each rank's slice
    # of the scales lines up with its slice of the packed values.
    intermediate_size = self.config.intermediate_size
    intermediate_size_block = intermediate_size // mxfp4_block
    per_rank_intermediate_size_block = cdiv(intermediate_size_block,
                                            tp_size)
    per_rank_intermediate_size = (per_rank_intermediate_size_block *
                                  mxfp4_block)

    # Calculate common slicing bounds for current rank
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                      intermediate_size)

    def load_moe_param(name: str, tensor: torch.Tensor) -> None:
        # The per-rank slice is already taken by the caller, so the
        # parameter's loader gets no shard/expert id.
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader",
                                default_weight_loader)
        weight_loader(param,
                      tensor,
                      weight_name=name,
                      shard_id=None,
                      expert_id=None)
        loaded_params.add(name)

    for name, weight in weights:
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            continue

        # FIXME(woosuk): Remove this after testing.
        weight = weight.cuda()

        # The "*_weight_scale" checks must precede the "*_weight" ones, which
        # are substrings of them.
        if ".w13_weight_scale" in name:
            # Handle MLP gate and up projection weights scale
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # gate and up are packed along dim 1, hence 2x bounds.
                narrow_weight = weight[:,
                                       2 * tp_rank_start:2 * tp_rank_end,
                                       ...]
            load_moe_param(name, narrow_weight)
            continue
        elif ".w2_weight_scale" in name:
            # Handle MLP down projection weights scale
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # One scale per MXFP4 block along the last dim.
                narrow_weight = weight[..., tp_rank_start //
                                       mxfp4_block:tp_rank_end //
                                       mxfp4_block]
            load_moe_param(name, narrow_weight)
            continue
        elif ".w13_weight" in name:
            # Handle MLP gate and up projection weights
            # flat weight from (E, 2 * N, block_size, entry_per_block)
            # to (E, 2 * N, -1), shouldn't trigger copy for contiguous
            weight = weight.view(num_experts, 2 * intermediate_size,
                                 -1).contiguous()

            # Extract gate and up projection parts
            # since the weight is shuffled, we can slice directly
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:,
                                       2 * tp_rank_start:2 * tp_rank_end,
                                       ...]
            load_moe_param(name, narrow_weight)
            continue
        elif ".w2_weight" in name:
            # Handle MLP down projection weights
            # same flatten here, but since 2 mx4 value are packed in 1
            # uint8, divide by 2
            weight = weight.view(num_experts, -1,
                                 intermediate_size // 2).contiguous()
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[...,
                                       tp_rank_start // 2:tp_rank_end // 2]
            load_moe_param(name, narrow_weight)
            continue
        elif ".w13_bias" in name:
            # Handle MLP gate and up projection biases
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:,
                                       2 * tp_rank_start:2 * tp_rank_end]
            load_moe_param(name, narrow_weight)
            continue
        elif ".w2_bias" in name:
            # Handle MLP down projection bias
            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            elif tp_rank != 0:
                # Only rank 0 loads the bias to avoid duplicating it. Build a
                # fresh zero tensor instead of zeroing the checkpoint tensor
                # in place.
                weight = torch.zeros_like(weight)
            load_moe_param(name, weight)
            continue
        elif "sinks" in name:
            # Handle attention sinks (distributed across ranks)
            param = params_dict[name]
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            break
        else:
            # Handle all other weights with potential renaming
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight)
        loaded_params.add(name)
    return loaded_params

_load_weights_other

_load_weights_other(
    ep_rank_start: int,
    ep_rank_end: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]
Source code in vllm/model_executor/models/gpt_oss.py
def _load_weights_other(
    self,
    ep_rank_start: int,
    ep_rank_end: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Load non-MXFP4 checkpoint weights into this model's parameters.

    MoE expert weights/biases are sliced along dim 0 (the expert axis) when
    expert parallelism is enabled, otherwise along the intermediate axis for
    the current tensor-parallel rank. Attention sinks are narrowed to this
    rank's heads, and separate q/k/v projections are routed into the fused
    ``qkv`` parameter via ``stacked_params_mapping``.

    Args:
        ep_rank_start: First expert index (inclusive) owned by this EP rank.
        ep_rank_end: Expert index (exclusive) ending this EP rank's range.
        heads_per_rank: Number of attention heads held by this TP rank.
        head_start: Index of this TP rank's first attention head.
        weights: Iterable of ``(name, tensor)`` checkpoint entries.
        stacked_params_mapping: ``(param_name, shard_name, shard_id)``
            tuples mapping per-projection checkpoint names onto fused params.

    Returns:
        The set of parameter names that were loaded.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    use_ep = self.parallel_config.enable_expert_parallel

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    intermediate_size = self.config.intermediate_size
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
    # Slice bounds of this TP rank along the intermediate axis; the last
    # rank's slice is clamped when the size is not evenly divisible.
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                      intermediate_size)

    for name, weight in weights:
        # Skip parameters that are not held on this rank.
        if is_pp_missing_parameter(name, self):
            continue

        if ".w13_weight" in name:
            # Fused gate/up projection weights. The 2x bounds assume the
            # last checkpoint dim packs both halves — TODO confirm layout.
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, :,
                                       2 * tp_rank_start:2 * tp_rank_end]

            # Swap the last two dims to match the parameter layout.
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[name]

            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w2_weight" in name:
            # Down projection weights, sliced on the intermediate axis (dim 1).
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[name]

            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w13_bias" in name:
            # Fused gate/up projection biases (same 2x bounds as weights).
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:,
                                       2 * tp_rank_start:2 * tp_rank_end]

            param = params_dict[name]
            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w2_bias" in name:
            # Down projection bias. Without EP every TP rank sees the full
            # bias; only rank 0 keeps it (presumably because the MoE output
            # is summed across TP ranks) and the other ranks load zeros.
            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            elif tp_rank != 0:
                # Build a fresh zero tensor rather than zeroing ``weight``
                # in place, so the caller's checkpoint tensor is untouched.
                weight = torch.zeros_like(weight)
            param = params_dict[name]
            param.copy_(weight)
            loaded_params.add(name)
            continue
        elif "sinks" in name:
            # Attention sinks: keep only this TP rank's heads.
            param = params_dict[name]
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            # Route a separate q/k/v checkpoint tensor into the fused param.
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            break
        else:
            # Any other weight: load by name when this model has it.
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight)
        loaded_params.add(name)
    return loaded_params

forward

forward(
    input_ids: Tensor,
    positions: Tensor,
    intermediate_tensors: Optional[
        IntermediateTensors
    ] = None,
    inputs_embeds: Optional[Tensor] = None,
) -> Tensor
Source code in vllm/model_executor/models/gpt_oss.py
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run this pipeline stage's range of decoder layers.

    On the first pipeline rank the hidden states come from ``inputs_embeds``
    when provided, otherwise from embedding ``input_ids``; later ranks resume
    from ``intermediate_tensors``. Non-final ranks hand ``hidden_states`` and
    ``residual`` on as ``IntermediateTensors``; the final rank applies the
    closing norm and returns the normalized hidden states.
    """
    if not get_pp_group().is_first_rank:
        # Resume from the activations produced by the previous stage.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    elif inputs_embeds is None:
        hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        hidden_states = inputs_embeds
        residual = None

    for layer_no in range(self.start_layer, self.end_layer):
        block = self.layers[layer_no]
        hidden_states, residual = block(hidden_states, positions, residual)

    if get_pp_group().is_last_rank:
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
    return IntermediateTensors({
        "hidden_states": hidden_states,
        "residual": residual
    })

get_input_embeddings

get_input_embeddings(input_ids: Tensor) -> Tensor
Source code in vllm/model_executor/models/gpt_oss.py
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Map token ids to embeddings using this model's ``embedding`` module."""
    embed = self.embedding
    return embed(input_ids)

load_weights

load_weights(
    weights: Iterable[tuple[str, Tensor]],
) -> set[str]
Source code in vllm/model_executor/models/gpt_oss.py
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights, dispatching on the quantization method.

    Computes this rank's attention-head slice (tensor parallel) and expert
    slice (expert parallel), then delegates to ``_load_weights_mxfp4`` for
    MXFP4 checkpoints and to ``_load_weights_other`` for everything else.

    Args:
        weights: Iterable of ``(name, tensor)`` checkpoint entries.

    Returns:
        The set of parameter names that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv", ".q_proj", "q"),
        (".qkv", ".k_proj", "k"),
        (".qkv", ".v_proj", "v"),
    ]

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    # Attention heads per rank
    heads_per_rank = self.config.num_attention_heads // tp_size
    head_start = tp_rank * heads_per_rank

    # Contiguous expert range [ep_rank_start, ep_rank_end) owned by this
    # EP rank.
    ep_size = get_ep_group().world_size
    ep_rank = get_ep_group().rank
    num_experts = self.config.num_local_experts
    experts_per_rank = num_experts // ep_size
    ep_rank_start = ep_rank * experts_per_rank
    ep_rank_end = (ep_rank + 1) * experts_per_rank

    # Treat a missing *or* None ``quantization_config`` as unquantized.
    quant_config = getattr(self.config, "quantization_config", None)
    quant_method = (quant_config["quant_method"]
                    if quant_config is not None else None)
    if quant_method == "mxfp4":
        # Call left unchanged; the MXFP4 loader's signature is defined
        # outside this view. NOTE(review): confirm its positional order
        # (ep_rank_end, ep_rank_start, ...) matches this call.
        return self._load_weights_mxfp4(ep_rank_end, ep_rank_start,
                                        heads_per_rank, head_start,
                                        weights, stacked_params_mapping)
    # ``_load_weights_other`` declares ``ep_rank_start`` before
    # ``ep_rank_end``. Passing by keyword keeps the bounds from being
    # swapped; a swapped pair makes every EP slice ``weight[end:start]``
    # empty.
    return self._load_weights_other(
        ep_rank_start=ep_rank_start,
        ep_rank_end=ep_rank_end,
        heads_per_rank=heads_per_rank,
        head_start=head_start,
        weights=weights,
        stacked_params_mapping=stacked_params_mapping,
    )

MLPBlock

Bases: Module

Source code in vllm/model_executor/models/gpt_oss.py
class MLPBlock(torch.nn.Module):
    """Mixture-of-experts feed-forward block.

    A bfloat16 linear router produces one logit per expert, and those logits
    drive a ``FusedMoE`` layer (top-k routing with ``renormalize=True``,
    biased experts, ``swigluoai`` activation).
    """

    def __init__(
        self,
        config: GptOssConfig,
        layer_idx: int,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        num_experts = config.num_local_experts
        top_k = config.num_experts_per_tok

        self.layer_idx = layer_idx
        self.num_experts = num_experts
        self.experts_per_token = top_k
        # World size reported by ``dist``; 1 when ``dist`` is not
        # initialized.
        if dist.is_initialized():
            self.world_size = dist.get_world_size()
        else:
            self.world_size = 1
        self.router = torch.nn.Linear(config.hidden_size,
                                      num_experts,
                                      dtype=torch.bfloat16)
        # The intermediate size must split evenly over ``world_size``.
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(num_experts=num_experts,
                                top_k=top_k,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False,
                                has_bias=True,
                                activation="swigluoai")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Route ``x`` to its experts and return the combined expert output."""
        router_logits = self.router(x)
        return self.experts(hidden_states=x, router_logits=router_logits)

experts instance-attribute

experts = FusedMoE(
    num_experts=num_local_experts,
    top_k=num_experts_per_tok,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    reduce_results=True,
    renormalize=True,
    quant_config=quant_config,
    prefix=f"{prefix}.experts",
    apply_router_weight_on_input=False,
    has_bias=True,
    activation="swigluoai",
)

experts_per_token instance-attribute

experts_per_token = num_experts_per_tok

layer_idx instance-attribute

layer_idx = layer_idx

num_experts instance-attribute

num_experts = num_local_experts

router instance-attribute

router = Linear(
    hidden_size, num_local_experts, dtype=bfloat16
)

world_size instance-attribute

world_size = get_world_size() if is_initialized() else 1

__init__

__init__(
    config: GptOssConfig,
    layer_idx: int,
    quant_config: QuantizationConfig,
    prefix: str = "",
)
Source code in vllm/model_executor/models/gpt_oss.py
def __init__(
    self,
    config: GptOssConfig,
    layer_idx: int,
    quant_config: QuantizationConfig,
    prefix: str = "",
):
    """Build the router and the ``FusedMoE`` experts for one MLP block."""
    super().__init__()
    num_experts = config.num_local_experts
    top_k = config.num_experts_per_tok

    self.layer_idx = layer_idx
    self.num_experts = num_experts
    self.experts_per_token = top_k
    # World size reported by ``dist``; 1 when ``dist`` is not initialized.
    if dist.is_initialized():
        self.world_size = dist.get_world_size()
    else:
        self.world_size = 1
    self.router = torch.nn.Linear(config.hidden_size,
                                  num_experts,
                                  dtype=torch.bfloat16)
    # The intermediate size must split evenly over ``world_size``.
    assert config.intermediate_size % self.world_size == 0
    self.experts = FusedMoE(num_experts=num_experts,
                            top_k=top_k,
                            hidden_size=config.hidden_size,
                            intermediate_size=config.intermediate_size,
                            reduce_results=True,
                            renormalize=True,
                            quant_config=quant_config,
                            prefix=f"{prefix}.experts",
                            apply_router_weight_on_input=False,
                            has_bias=True,
                            activation="swigluoai")

forward

forward(x: Tensor) -> Tensor
Source code in vllm/model_executor/models/gpt_oss.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Route ``x`` to its experts and return the combined expert output."""
    router_logits = self.router(x)
    return self.experts(hidden_states=x, router_logits=router_logits)

OAIAttention

Bases: Module

Source code in vllm/model_executor/models/gpt_oss.py
class OAIAttention(nn.Module):
    """GPT-OSS self-attention.

    Combines a fused QKV projection, YaRN rotary embeddings, per-head
    attention sinks, and a sliding window on even-indexed layers.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = head_dim = config.head_dim
        self.num_attention_heads = num_heads = config.num_attention_heads
        self.num_key_value_heads = num_kv_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # YaRN parameters copied from the HF config's ``rope_scaling``.
        yarn_keys = ("factor", "original_max_position_embeddings",
                     "beta_fast", "beta_slow")
        rope_scaling = {"rope_type": "yarn"}
        rope_scaling.update((key, config.rope_scaling[key])
                            for key in yarn_keys)
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling=rope_scaling,
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()
        num_local_heads = num_heads // tp_size

        # One sink value per attention head held by this TP rank.
        self.sinks = torch.nn.Parameter(
            torch.empty(num_local_heads,
                        dtype=torch.bfloat16,
                        requires_grad=False))

        self.q_size = num_heads * head_dim // tp_size
        self.kv_size = num_kv_heads * head_dim // tp_size
        self.scaling = head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=head_dim,
            total_num_heads=num_heads,
            total_num_kv_heads=num_kv_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=num_heads * head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = num_local_heads
        self.num_local_key_value_heads = num_kv_heads // tp_size

        # Even-indexed layers get the configured sliding window; odd ones
        # pass ``None``.
        is_sliding_layer = self.layer_idx % 2 == 0
        sliding_window = config.sliding_window if is_sliding_layer else None
        self.attn = Attention(
            num_local_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attended = self.attn(q, k, v.contiguous())
        projected, _ = self.o_proj(attended)
        return projected

attn instance-attribute

attn = Attention(
    num_local_attention_heads,
    head_dim,
    scaling,
    num_kv_heads=num_local_key_value_heads,
    cache_config=cache_config,
    quant_config=quant_config,
    per_layer_sliding_window=sliding_window,
    attn_type=DECODER,
    prefix=f"{prefix}.attn",
    sinks=sinks,
)

head_dim instance-attribute

head_dim = head_dim

hidden_size instance-attribute

hidden_size = hidden_size

kv_size instance-attribute

kv_size = num_key_value_heads * head_dim // tp_size

layer_idx instance-attribute

layer_idx = extract_layer_index(prefix)

num_attention_heads instance-attribute

num_attention_heads = num_attention_heads

num_key_value_heads instance-attribute

num_key_value_heads = num_key_value_heads

num_local_attention_heads instance-attribute

num_local_attention_heads = num_attention_heads // tp_size

num_local_key_value_heads instance-attribute

num_local_key_value_heads = num_key_value_heads // tp_size

o_proj instance-attribute

o_proj = RowParallelLinear(
    input_size=num_attention_heads * head_dim,
    output_size=hidden_size,
    quant_config=quant_config,
    prefix=f"{prefix}.o_proj",
)

q_size instance-attribute

q_size = num_attention_heads * head_dim // tp_size

qkv instance-attribute

qkv = QKVParallelLinear(
    hidden_size=hidden_size,
    head_size=head_dim,
    total_num_heads=num_attention_heads,
    total_num_kv_heads=num_key_value_heads,
    quant_config=quant_config,
    prefix=f"{prefix}.qkv_proj",
)

rope_theta instance-attribute

rope_theta = rope_theta

rotary_emb instance-attribute

rotary_emb = get_rope(
    head_dim,
    rotary_dim=head_dim,
    max_position=max_position_embeddings,
    base=rope_theta,
    dtype=float32,
    rope_scaling={
        "rope_type": "yarn",
        "factor": rope_scaling["factor"],
        "original_max_position_embeddings": rope_scaling[
            "original_max_position_embeddings"
        ],
        "beta_fast": rope_scaling["beta_fast"],
        "beta_slow": rope_scaling["beta_slow"],
    },
    is_neox_style=True,
)

scaling instance-attribute

scaling = head_dim ** -0.5

sinks instance-attribute

sinks = Parameter(
    empty(
        num_attention_heads // tp_size,
        dtype=bfloat16,
        requires_grad=False,
    )
)

__init__

__init__(
    config: GptOssConfig,
    quant_config: Optional[QuantizationConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "",
)
Source code in vllm/model_executor/models/gpt_oss.py
def __init__(
    self,
    config: GptOssConfig,
    quant_config: Optional[QuantizationConfig] = None,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "",
):
    """Build the rotary embedding, sinks, QKV/output projections and the
    attention backend for one GPT-OSS attention layer."""
    super().__init__()
    self.layer_idx = extract_layer_index(prefix)
    self.head_dim = head_dim = config.head_dim
    self.num_attention_heads = num_heads = config.num_attention_heads
    self.num_key_value_heads = num_kv_heads = config.num_key_value_heads
    self.hidden_size = config.hidden_size

    # YaRN parameters copied from the HF config's ``rope_scaling``.
    yarn_keys = ("factor", "original_max_position_embeddings",
                 "beta_fast", "beta_slow")
    rope_scaling = {"rope_type": "yarn"}
    rope_scaling.update((key, config.rope_scaling[key]) for key in yarn_keys)
    self.rotary_emb = get_rope(
        head_dim,
        rotary_dim=head_dim,
        max_position=config.max_position_embeddings,
        base=config.rope_theta,
        dtype=torch.float32,
        rope_scaling=rope_scaling,
        is_neox_style=True,
    )

    tp_size = get_tensor_model_parallel_world_size()
    num_local_heads = num_heads // tp_size

    # One sink value per attention head held by this TP rank.
    self.sinks = torch.nn.Parameter(
        torch.empty(num_local_heads,
                    dtype=torch.bfloat16,
                    requires_grad=False))

    self.q_size = num_heads * head_dim // tp_size
    self.kv_size = num_kv_heads * head_dim // tp_size
    self.scaling = head_dim**-0.5
    self.rope_theta = config.rope_theta

    self.qkv = QKVParallelLinear(
        hidden_size=self.hidden_size,
        head_size=head_dim,
        total_num_heads=num_heads,
        total_num_kv_heads=num_kv_heads,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )

    self.o_proj = RowParallelLinear(
        input_size=num_heads * head_dim,
        output_size=self.hidden_size,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )

    self.num_local_attention_heads = num_local_heads
    self.num_local_key_value_heads = num_kv_heads // tp_size

    # Even-indexed layers get the configured sliding window; odd ones
    # pass ``None``.
    is_sliding_layer = self.layer_idx % 2 == 0
    sliding_window = config.sliding_window if is_sliding_layer else None
    self.attn = Attention(
        num_local_heads,
        head_dim,
        self.scaling,
        num_kv_heads=self.num_local_key_value_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        per_layer_sliding_window=sliding_window,
        attn_type=AttentionType.DECODER,
        prefix=f"{prefix}.attn",
        sinks=self.sinks,
    )

forward

forward(hidden_states: Tensor, positions: Tensor) -> Tensor
Source code in vllm/model_executor/models/gpt_oss.py
def forward(self, hidden_states: torch.Tensor,
            positions: torch.Tensor) -> torch.Tensor:
    """Project to q/k/v, apply rotary embeddings, attend, project out."""
    qkv, _ = self.qkv(hidden_states)
    split_sizes = [self.q_size, self.kv_size, self.kv_size]
    q, k, v = qkv.split(split_sizes, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    attended = self.attn(q, k, v.contiguous())
    projected, _ = self.o_proj(attended)
    return projected

TransformerBlock

Bases: Module

Source code in vllm/model_executor/models/gpt_oss.py
class TransformerBlock(torch.nn.Module):
    """One GPT-OSS decoder layer: attention followed by an MoE MLP.

    Each sub-layer is preceded by an RMSNorm. When called with a residual,
    the norm returns the updated ``(hidden_states, residual)`` pair.
    """

    def __init__(
        self,
        config: GptOssConfig,
        cache_config: CacheConfig,
        quant_config: QuantizationConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        # Note: ``quant_config`` is forwarded to the MLP block only; the
        # attention module is constructed without it.
        self.attn = OAIAttention(config,
                                 prefix=f"{prefix}.attn",
                                 cache_config=cache_config)
        self.mlp = MLPBlock(config,
                            self.layer_idx,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")
        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and the MoE MLP to ``hidden_states``.

        Args:
            hidden_states: Input activations for this layer.
            positions: Token positions, forwarded to the attention block.
            residual: Running residual from the previous layer, or ``None``
                (as passed on the first pipeline rank), in which case the
                input itself becomes the residual.

        Returns:
            ``(output, residual)``: the MLP output and the residual to pass
            to the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.attn(hidden_states, positions)
        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        output = self.mlp(hidden_states)
        return output, residual

attn instance-attribute

attn = OAIAttention(
    config,
    prefix=f"{prefix}.attn",
    cache_config=cache_config,
)

input_layernorm instance-attribute

input_layernorm = RMSNorm(hidden_size, eps=1e-05)

layer_idx instance-attribute

layer_idx = extract_layer_index(prefix)

mlp instance-attribute

mlp = MLPBlock(
    config,
    layer_idx,
    quant_config=quant_config,
    prefix=f"{prefix}.mlp",
)

post_attention_layernorm instance-attribute

post_attention_layernorm = RMSNorm(hidden_size, eps=1e-05)

__init__

__init__(
    config: GptOssConfig,
    cache_config: CacheConfig,
    quant_config: QuantizationConfig,
    prefix: str = "",
)
Source code in vllm/model_executor/models/gpt_oss.py
def __init__(
    self,
    config: GptOssConfig,
    cache_config: CacheConfig,
    quant_config: QuantizationConfig,
    prefix: str = "",
):
    """Build one decoder layer's attention, MoE MLP and its two RMSNorms.

    ``quant_config`` is forwarded to the MLP block only; the attention
    module is constructed without it.
    """
    super().__init__()
    layer_idx = extract_layer_index(prefix)
    self.layer_idx = layer_idx
    self.attn = OAIAttention(config,
                             cache_config=cache_config,
                             prefix=f"{prefix}.attn")
    self.mlp = MLPBlock(config,
                        layer_idx,
                        prefix=f"{prefix}.mlp",
                        quant_config=quant_config)
    norm_eps = 1e-5
    self.input_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=norm_eps)

forward

forward(
    hidden_states: Tensor,
    positions: Tensor,
    residual: Optional[Tensor],
) -> Tensor
Source code in vllm/model_executor/models/gpt_oss.py
def forward(
    self,
    hidden_states: torch.Tensor,
    positions: torch.Tensor,
    residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply attention and the MoE MLP to ``hidden_states``.

    Args:
        hidden_states: Input activations for this layer.
        positions: Token positions, forwarded to the attention block.
        residual: Running residual from the previous layer, or ``None``
            (as passed on the first pipeline rank), in which case the input
            itself becomes the residual.

    Returns:
        ``(output, residual)``: the MLP output and the residual to pass to
        the next layer. (The previous ``-> torch.Tensor`` annotation did not
        match this tuple return.)
    """
    # Self Attention
    if residual is None:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
    else:
        hidden_states, residual = self.input_layernorm(
            hidden_states, residual)
    hidden_states = self.attn(hidden_states, positions)
    # Fully Connected
    hidden_states, residual = self.post_attention_layernorm(
        hidden_states, residual)
    output = self.mlp(hidden_states)
    return output, residual