| |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import datasets |
| import plotly.graph_objects as go |
| import numpy as np |
| import polars as pl |
|
|
|
|
# Count how many Yi-34B tokens each Alpaca training example ("text" field) has.
# NOTE(review): trust_remote_code=True runs code downloaded from the Hub;
# confirm it is still required for this tokenizer before keeping it.
tokenizer = AutoTokenizer.from_pretrained("01-ai/Yi-34B", trust_remote_code=True)
alpaca = datasets.load_dataset("tatsu-lab/alpaca", split="train").map(
    # Batched: one tokenizer call per batch of texts instead of one per example.
    lambda batch: {
        "tokens": [len(ids) for ids in tokenizer(batch["text"])["input_ids"]]
    },
    batched=True,
    num_proc=4,
)

# `index` is each example's row position (Int64). It is built with numpy so it
# does not rely on `pl.count()`, which newer Polars releases deprecate.
pdf = pl.DataFrame(alpaca.to_pandas())
pdf = pdf.with_columns(pl.Series("index", np.arange(pdf.height, dtype=np.int64)))
tokens = pdf["tokens"].to_numpy()
|
|
| |
|
|
|
|
def plot_batch(batch_size):
    """Plot the padding needed by one batch of the first ``batch_size`` examples.

    Each example in the leading ``batch_size`` rows of ``pdf["tokens"]`` gets
    one horizontal bar (row 1 is the first example). The blue segment is its
    real token count. The red segment stacked after it is the padding needed
    to reach the longest example in the batch.

    Args:
        batch_size: Number of leading examples to include (slice semantics).

    Returns:
        A ``plotly.graph_objects.Figure`` with the stacked bars.

    Raises:
        ValueError: If the slice selects no examples.
    """
    # Slice before anything else: no need to copy the whole column for a prefix.
    data = pdf["tokens"].to_numpy()[:batch_size]
    if data.size == 0:
        raise ValueError("batch_size must select at least one example")
    max_value = data.max()
    rows = np.arange(1, data.size + 1)

    # Two array traces stacked per row draw the same picture as one pair of
    # single-bar traces per example, without growing the figure by 2 traces/row.
    fig = go.Figure(
        data=[
            go.Bar(x=data, y=rows, orientation="h", marker_color="blue"),
            go.Bar(x=max_value - data, y=rows, orientation="h", marker_color="red"),
        ]
    )
    fig.update_layout(
        barmode="stack",
        showlegend=False,
        xaxis=dict(range=[0, max_value]),
    )
    return fig
|
|
|
|
def packing(pocket=8192, lengths=None):
    """Return the compute-cost ratio of greedily packing sequences into pockets.

    Sequence lengths are visited in order. Each one is appended to the
    currently open pocket while the total stays within ``pocket``; otherwise
    the pocket is closed and a new one starts (sequences are never split).
    The ratio is the allocated capacity ``num_pockets * pocket`` divided by
    the total number of real tokens, so 1.0 means no padding.

    A sequence longer than ``pocket`` is placed alone and counted as a single
    pocket, so the ratio understates its real cost.

    Args:
        pocket: Capacity of one pocket, in tokens.
        lengths: Sequence lengths to pack. Defaults to the module-level
            ``tokens`` array.

    Returns:
        ``num_pockets * pocket / sum(lengths)``.

    Raises:
        ValueError: If ``pocket`` is not positive or ``lengths`` sums to zero.
    """
    if lengths is None:
        lengths = tokens
    if pocket <= 0:
        raise ValueError(f"pocket must be positive, got {pocket}")
    total = np.sum(lengths)
    if total <= 0:
        raise ValueError("lengths must contain at least one token")

    num_pocket = 0
    used = 0  # tokens already placed in the currently open pocket
    for length in lengths:
        # Close the pocket only when it holds something, so an oversized
        # sequence arriving at an empty pocket does not count an empty pocket.
        if used and used + length > pocket:
            num_pocket += 1
            used = length
        else:
            used += length
    if used:
        num_pocket += 1
    return num_pocket * pocket / total
|
|
|
|
| |
|
|
# Show the padding picture for the first 30 examples. Call .show() explicitly:
# outside a notebook a bare expression statement discards the returned figure.
plot_batch(30).show()
|
|
| |
# Batching baseline. For each batch size, group consecutive examples into
# batches, pad every example to its batch's longest one, and record
# (average real tokens per example's batch, padded tokens / real tokens).
arrs = []
for batch_size in range(1, 100):
    batch_id = pl.col("index") // batch_size
    arr = (
        pdf.with_columns(batch=pl.col("tokens").max().over(batch_id))
        .select(
            # x: real tokens in each example's batch, averaged over examples.
            pl.col("tokens").sum().over(batch_id).mean(),
            # y: total padded compute over total real tokens. This is the same
            # ratio-of-sums used by `packing` and by tokens.max() / tokens.mean()
            # (the "Worst" line). A mean of per-example ratios would overweight
            # short examples and could exceed "Worst".
            pl.col("batch").sum() / pl.col("tokens").sum(),
        )
        .to_numpy()
    )
    arrs.append(arr)
x_values, y_values = np.concatenate(arrs).transpose()

# Packing curve over the same x range. It starts at the longest example so that
# every sequence fits in a pocket.
pxs = np.linspace(tokens.max(), x_values[-1], 100)
pys = [packing(pocket) for pocket in pxs]
|
|
|
|
# Compare cost ratios: per-batch padding ("Batching"), greedy packing
# ("Packing"), padding everything to the global maximum ("Worst"), and the
# packing result at a pocket size of 8192 ("Chosen").
worst = tokens.max() / tokens.mean()

fig = go.Figure()
for trace in (
    go.Scatter(x=x_values, y=y_values, mode="lines", name="Batching"),
    go.Scatter(x=pxs, y=pys, mode="lines", name="Packing"),
    go.Scatter(
        x=x_values,
        y=[worst] * len(x_values),
        mode="lines",
        name="Worst",
        line={"dash": "dash"},
    ),
    go.Scatter(x=[8192], y=[packing(8192)], mode="markers", name="Chosen"),
):
    fig.add_trace(trace)

fig.update_layout(
    xaxis_title="throughput(tokens)",
    yaxis_title="computational cost(ratio)",
    yaxis={"range": [0, worst + 1]},
)
fig.show()
|
|