برداری سازی در JAX



مطالعهٔ بهینه‌سازی محاسبات عددی برای کارایی بهینه، به ویژه در محاسبات علمی و یادگیری ماشین بسیار ضروری هست. در این مطلب، به بررسی روش‌های مختلف برای افزایش سرعت محاسبه با استفاده از کتابخانهٔ JAX، یک کتابخانهٔ برای محاسبات عددی و تفاضل‌پذیری خودکار، می‌پردازیم. روش های بردارسازی ساده، دستی با استفاده از einsum و خودکار با استفاده از vmap، را بررسی میکنیم و در نهایت، با استفاده از JIT، که Just-In-Time نامیده می‌شود، عملکرد تابع را بهبود میدیم.
نکته مورد توجه در این مطلب این هست که چطور میشه بدون دست زدن به تابع اولیه اون را برای اعمال روی ابعاد بالاتر مناسب سازی کنیم و در عین حال مراقب عملکرد بهینه محاسبات هم باشیم.
یک سری جاها برای نداشتن معادل فارسی مناسب یا انتقال بهتر مفهوم از کلمات انگلیسی استفاده کنم. در واقع بعضی ترجمه های فارسی احتمالا به گوشمون نا‌آشناست.
همراه باشین


def dot(v1, v2):
    return jax.numpy.dot(v1, v2)

Naively vectorizing - برداری سازی ساده با لیست

dot_naive =[dot(v1, v2) for v1, v2 in zip(v1s, v2s)]

اینجا به اصطلاح یک list comprehension داریم که محاسبات را برای هر یک از سطرهای بردار v1s و v2s انجام میدهد. فرض کردیم v1s یک آرایه دو بعدی هست. چیزی شبیه این:

from jax import random
rng_key = random.PRNGKey(42)
vs = random.normal(rng_key, shape=(1000,3))
v1s = vs[:500, :]
v2s = vs[500:, :]


Manual vectorizing - بردارسازی دستی

def dot_vectorized(v1s, v2s):
    return jnp.einsum(&quotij,ij->i&quot, v1s, v2s)

در این روش از تابع جمع اینشین einsum استفاده میکنیم که به عنوان چاقوی سویسی هم شناخته میشه و کاربردهای زیادی در کتابخانه numpy داره. لازمه برداری سازی دستی این هست که تابع اولیه را بازنویسی کنیم. اما آیا راهی وجود داره که بدون دست زدن به تابع اولیه اون رو برای اعمال روی بردار ها یا آرایه های با ابعاد بیشتر هم بهینه کنیم؟

Automatic vectorizing - برداری سازی خودکار

dot_vmapped = jax.vmap(dot)

برداری سازی با vmap اجازه میده تابع را برای batch های ورودی به کار ببریم. یعنی همون ورودی با آرایه های با ابعاد بیشتر که در اینجا دو بعد داریم. تابع vmap به طور پیشفرض بعد اول آرایه های ورودی را به عنوان batching dimension در نظر میگیره که البته قابل ویرایش هست.


زمان گیری


%timeit [dot(v1, v2) for v1, v2 in zip(v1s, v2s)]
%timeit dot_vectorized(v1s, v2s).block_until_ready()
%timeit dot_vmapped(v1s, v2s).block_until_ready()
5.15 ms ± 54.3 µs per loop
135 µs ± 171 ns per loop
543 µs ± 1.38 µs per loop

همون طور که میبینیم هنوز برداری سازی دستی عملکرد بهتری داره.

اما با اضافه کردن jit این اختلاف از بین میره


dot_vectorized_jitted = jax.jit(dot_vectorized)
dot_vmapped_jitted = jax.jit(dot_vmapped)
6.5 µs ± 12.9 ns per loop
6.39 µs ± 13.4 ns per loop

به طور خلاصه ترکیب استفاده از jit و vmap میتونه عملکرد تابع را به خوبی بهبود بده و اون رو بدون نیاز به ویرایش برای استفاده روی ورودی های با ابعاد بالاتر مناسب کنه.
برای دسترسی به نوتبوک می تونید به اینجا نگاه کنید.
این مطلب خلاصه ای بود از فصل vectorizing کتاب deep learning in jax.