oops, forgot to subtract embedding params, which don't enter the 6ND equation

pull/122/head
Andrej Karpathy 2023-02-04 22:33:35 +00:00
parent 5a162bc773
commit 3341b4cecc
1 changed files with 4 additions and 2 deletions

View File

@ -267,7 +267,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"palm_flops: 879894724608, flops: 874944921600, ratio: 1.0057\n"
"palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n"
]
}
],
@ -276,7 +276,9 @@
"# this formula is often used to calculate MFU (model flops utilization)\n",
"def palm_flops():\n",
" \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n",
" N = params()['total']\n",
" # non-embedding model parameters. note that we do not subtract the\n",
" # embedding/token params because those are tied and get used in the last layer.\n",
" N = params()['total'] - params()['emebedding/position']\n",
" L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n",
" mf_per_token = 6*N + 12*L*H*Q*T\n",
" mf = mf_per_token * block_size\n",