Llama2 Transformer Network Structure

  • Transformer structure

    from typing import Optional

    import torch.nn as nn
    from torch import Tensor

    class Transformer(nn.Module):
        def __init__(self, config: ModelArgs) -> None:
            super().__init__()
            self.config = config

            # Token embedding table: vocab_size -> dim
            self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
            # n_layer stacked TransformerBlocks
            self.layers = nn.ModuleList(
                TransformerBlock(config) for _ in range(config.n_layer)
            )
            # Final normalization and the vocabulary projection head
            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
            self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

            # Caches for rotary-embedding frequencies and the causal mask,
            # filled in lazily once batch size / sequence length are known
            self.freqs_cis: Optional[Tensor] = None
            self.mask_cache: Optional[Tensor] = None
            self.max_batch_size = -1
            self.max_seq_length = -1
  • Embedding: nn.Embedding(vocab_size, dim), maps token ids to dim-dimensional vectors
  • layers: n_layer stacked TransformerBlock modules
  • RMSNorm: final normalization before the output head (see the sketch after this list)
  • Linear: output projection from dim to vocab_size logits, with bias=False
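    Unlike LayerNorm, RMSNorm skips mean-centering and the bias term: it rescales each vector by the reciprocal root mean square of its elements times a learned per-dimension gain. A minimal sketch matching the RMSNorm(dim, eps) signature used above (the default eps value here is an assumption):

    import torch
    import torch.nn as nn
    from torch import Tensor

    class RMSNorm(nn.Module):
        def __init__(self, dim: int, eps: float = 1e-5) -> None:
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))  # learned per-dimension gain

        def forward(self, x: Tensor) -> Tensor:
            # Scale by the reciprocal root-mean-square over the feature dimension;
            # no mean subtraction and no bias, unlike LayerNorm
            rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return x * rms * self.weight

    Applied to an input of shape (batch, seq_len, dim), the output shape is unchanged; only the scale of each feature vector is normalized.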