Android Nomad - #30

Exploring the Beauty of Shaders in Android with Jetpack Compose


Do you need shaders? Probably not, or maybe!

Introduction

Shaders are the hidden gems of modern graphics development. These small, specialized programs can create mesmerizing visual effects, turning your ordinary UI into an extraordinary masterpiece. When combined with Jetpack Compose, the experience becomes even more immersive. In this blog post, we will explore how to harness the power of shaders, specifically GLSL (OpenGL Shading Language), within Jetpack Compose.

Understanding Shaders

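Shaders are small programs that run on the GPU (Graphics Processing Unit) and control how objects in a graphical scene are drawn: a vertex shader positions the geometry, and a fragment shader computes the color of every pixel that geometry covers. They can be used to create effects like blurs, glows, color transformations, and more. Compose has no API for running GLSL directly, so in this post we render the shader with OpenGL ES 2.0 in a GLSurfaceView and embed that view in Compose with AndroidView.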
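One way to picture a fragment shader is as a function from a pixel’s position (plus a few uniforms such as time and resolution) to a color, which the GPU evaluates for every pixel in parallel. The Kotlin sketch below models that idea on the CPU with a plain loop. `FragmentShader`, `renderOnCpu`, and `gradient` are hypothetical names used only for illustration; a real GPU does not process pixels in this order.

```kotlin
// Hypothetical model: a fragment shader maps one pixel position to an RGB color.
fun interface FragmentShader {
    fun shade(fragX: Float, fragY: Float, width: Float, height: Float): FloatArray
}

// Evaluate the shader once per pixel, like the GPU does (but sequentially here).
// Pixel centers sit at +0.5, matching gl_FragCoord; row 0 is the bottom row in GL.
fun renderOnCpu(width: Int, height: Int, shader: FragmentShader): Array<FloatArray> =
    Array(width * height) { index ->
        val x = index % width
        val y = index / width
        shader.shade(x + 0.5f, y + 0.5f, width.toFloat(), height.toFloat())
    }

// A horizontal gradient: the red channel grows from left to right.
val gradient = FragmentShader { fragX, _, width, _ -> floatArrayOf(fragX / width, 0f, 0f) }
```

The point of the model is that the function sees one pixel at a time and has no access to its neighbors’ results, which is exactly the constraint GLSL fragment shaders work under.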

If you’re starting from scratch, I highly recommend the book Learn OpenGL, which serves as a good base.

Shadertoy has plenty of GLSL shaders that you can adapt for your project.
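Shadertoy shaders are not drop-in, though. They implement `mainImage(out vec4 fragColor, in vec2 fragCoord)` and read Shadertoy’s built-in uniforms such as `iTime` and `iResolution` (a `vec3`). The renderer in this post expects a `main()` that writes `gl_FragColor` and reads `u_time` and `u_resolution`. Below is a minimal sketch of a string adapter. It assumes the pasted shader only relies on `iTime` and `iResolution` (no `iMouse`, `iChannel0`, and so on) and sticks to features available in GLSL ES 1.00. Shadertoy compiles for WebGL 2, so some shaders need further edits, for example `texture()` versus `texture2D()`. `wrapShadertoySource` is a hypothetical helper, not part of any library.

```kotlin
// Hypothetical helper: adapts a Shadertoy-style fragment shader to the
// u_time / u_resolution conventions used by ShaderRenderer in this post.
// Assumes the source only relies on the iTime and iResolution built-ins.
fun wrapShadertoySource(shadertoySource: String): String = buildString {
    appendLine("#ifdef GL_ES")
    appendLine("precision mediump float;")
    appendLine("#endif")
    appendLine("uniform float u_time;")
    appendLine("uniform vec2 u_resolution;")
    // Map Shadertoy's built-ins onto our uniforms; iResolution.z is the pixel aspect ratio.
    appendLine("#define iTime u_time")
    appendLine("#define iResolution vec3(u_resolution, 1.0)")
    appendLine(shadertoySource.trim())
    // OpenGL ES 2.0 entry point that delegates to Shadertoy's mainImage().
    appendLine("void main() {")
    appendLine("    mainImage(gl_FragColor, gl_FragCoord.xy);")
    appendLine("}")
}
```

The result can be passed to `setShaders()` as the fragment shader source in place of a hand-written one.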

Let’s build this screen in three steps:

  • Get the GLSL shaders
  • Render them with OpenGL in a GLSurfaceView
  • Host the GLSurfaceView in a composable

First, copy the following fragment shader and save it as a file in the res/raw directory.

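#version 100

#ifdef GL_ES
precision mediump float;
#endif

uniform float u_time;       // animation clock, set by the renderer every frame
uniform vec2 u_resolution;  // surface size in pixels, set by the renderer
uniform vec2 u_mouse;       // declared but not set or used in this example

// 2D rotation matrix (unused here, handy for experiments).
mat2 rotate(float angle){
    return mat2(cos(angle), -sin(angle), sin(angle), cos(angle));
}

void main(){
    // Map the pixel position to [-0.5, 0.5] and scale x by the aspect ratio.
    vec2 coord=(gl_FragCoord.xy/u_resolution)-.5;
    coord.x*=u_resolution.x/u_resolution.y;
    vec3 color=vec3(0.);

    vec2 newCoords=fract(coord)-.5; // unused

    // 1.0 on or above an animated sine/cosine wave, 0.0 below it.
    float flower=step(
    (sin(coord.x*100.)*cos(coord.y*100.+(u_time*5.)))*.27,
    coord.y
    );

    // Tint the "on" area yellow-green.
    color=(flower*vec3(.9216, 1., .2235));

    gl_FragColor=vec4(color, 1.);
}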
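To make the shader’s math easier to follow, here is a line-by-line CPU port of its `main()` in plain Kotlin, evaluated for a single pixel. It is only a reading aid. It assumes `fragX`/`fragY` play the role of `gl_FragCoord.xy` (pixel centers, origin at the bottom-left), and it leaves out the unused `rotate()`, `newCoords`, and `u_mouse`. `shadePixel` is a hypothetical name.

```kotlin
import kotlin.math.cos
import kotlin.math.sin

// Hypothetical CPU port of the fragment shader's main() for one pixel.
fun shadePixel(fragX: Float, fragY: Float, width: Float, height: Float, time: Float): FloatArray {
    // coord = gl_FragCoord.xy / u_resolution - 0.5, then x scaled by the aspect ratio.
    val x = (fragX / width - 0.5f) * (width / height)
    val y = fragY / height - 0.5f

    // The animated edge: a sin/cos pattern with amplitude 0.27 that shifts with time.
    val edge = sin(x * 100f) * cos(y * 100f + time * 5f) * 0.27f

    // GLSL step(edge, y) is 0.0 when y < edge and 1.0 otherwise.
    val flower = if (y < edge) 0f else 1f

    // color = flower * vec3(.9216, 1., .2235)
    return floatArrayOf(0.9216f * flower, 1f * flower, 0.2235f * flower)
}
```

Because the edge stays within [-0.27, 0.27], pixels with `coord.y` below -0.27 always come out black and pixels above 0.27 always take the yellow-green color. The animated ripple lives in the band between them.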

In this example, we’ll use a simple pass-through vertex shader. Save it in res/raw as well:

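// One corner of the full-screen quad, already in clip space.
// Only x and y are supplied; OpenGL fills in z = 0 and w = 1.
attribute vec4 a_Position;

void main (){
    gl_Position = a_Position;
}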
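The vertex shader only has to pass four corner positions through. They are already in clip space, where the visible area spans -1 to 1 on both axes, and `GL_TRIANGLE_STRIP` turns them into two triangles that cover that square exactly. The sketch below checks the coverage on the CPU using the same vertex order as `quadVertices` in the renderer. `stripTriangles` and `triangleArea` are hypothetical helpers.

```kotlin
import kotlin.math.abs

// Same order as quadVertices in ShaderRenderer: (x, y) pairs in clip space.
val quadVertices = floatArrayOf(
    -1f, 1f,   // top-left
    1f, 1f,    // top-right
    -1f, -1f,  // bottom-left
    1f, -1f,   // bottom-right
)

// A strip of n vertices draws n - 2 triangles from vertices (i, i+1, i+2)
// (ignoring the winding flip GL applies to every other triangle).
fun stripTriangles(vertexCount: Int): List<Triple<Int, Int, Int>> =
    (0 until vertexCount - 2).map { i -> Triple(i, i + 1, i + 2) }

// Shoelace formula for the area of one triangle given its vertex indices.
fun triangleArea(v: FloatArray, t: Triple<Int, Int, Int>): Float {
    fun x(i: Int) = v[2 * i]
    fun y(i: Int) = v[2 * i + 1]
    val (a, b, c) = t
    return abs((x(b) - x(a)) * (y(c) - y(a)) - (x(c) - x(a)) * (y(b) - y(a))) / 2f
}

// Total covered area; the [-1, 1] x [-1, 1] clip-space square has area 4.
val totalArea = stripTriangles(quadVertices.size / 2)
    .sumOf { triangleArea(quadVertices, it).toDouble() }
```

Since the two triangles add up to the full square, the fragment shader runs for every pixel of the viewport, which is why the same vertex shader works for almost any full-screen fragment shader.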

Next, we need a rendering class, ShaderRenderer, which implements GLSurfaceView.Renderer and overrides three methods:

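onSurfaceCreated(): set up global GL state such as the clear color. The shader program itself is built lazily: setShaders() flags it as changed, and the next onDrawFrame() compiles and links it on the GL thread.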

onDrawFrame(): called for every frame. Here we rebuild the program if needed, update the uniforms, and draw the full-screen quad.

onSurfaceChanged(): called when the surface size changes. We update the viewport and remember the size for the u_resolution uniform.
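GLSurfaceView calls these methods on its own render thread in a fixed order. onSurfaceCreated() runs when an EGL surface is created, onSurfaceChanged() runs after that and again whenever the size changes, and onDrawFrame() then runs repeatedly (continuously, in the default render mode). The hypothetical driver below only models that call order. It is not how GLSurfaceView is implemented, and `SimpleRenderer` mirrors `GLSurfaceView.Renderer` without the GL parameters.

```kotlin
// Hypothetical mirror of GLSurfaceView.Renderer, minus the GL10/EGLConfig parameters.
interface SimpleRenderer {
    fun onSurfaceCreated()
    fun onSurfaceChanged(width: Int, height: Int)
    fun onDrawFrame()
}

// Models the order in which a render thread drives a renderer:
// create once, report the size, then draw frame after frame.
fun driveRenderer(renderer: SimpleRenderer, width: Int, height: Int, frames: Int) {
    renderer.onSurfaceCreated()
    renderer.onSurfaceChanged(width, height)
    repeat(frames) { renderer.onDrawFrame() }
}

// A renderer that just records the calls it receives.
class LoggingRenderer : SimpleRenderer {
    val calls = mutableListOf<String>()
    override fun onSurfaceCreated() { calls += "created" }
    override fun onSurfaceChanged(width: Int, height: Int) { calls += "changed ${width}x$height" }
    override fun onDrawFrame() { calls += "draw" }
}
```

Because every callback runs on the render thread, anything the UI thread hands over (such as new shader sources) has to be published safely. The real renderer uses AtomicBoolean flags for that.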

I won’t describe how OpenGL works in detail because it’s out of the scope of this article. We’ll also focus only on the fragment shader and skip the vertex shader’s details: it just passes a full-screen quad through, so it stays the same for almost any fragment shader.

import android.graphics.Bitmap
import android.opengl.GLES20
import android.opengl.GLSurfaceView
import android.os.Trace
import androidx.palette.graphics.Palette
import androidx.palette.graphics.Target
import timber.log.Timber
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.util.concurrent.atomic.AtomicBoolean
import javax.microedition.khronos.egl.EGLConfig
import javax.microedition.khronos.opengles.GL10
import kotlin.math.roundToInt

open class ShaderRenderer : GLSurfaceView.Renderer {

    private val positionComponentCount = 2

    // Full-screen quad in clip space, drawn as a triangle strip.
    private val quadVertices by lazy {
        floatArrayOf(
            -1f, 1f,
            1f, 1f,
            -1f, -1f,
            1f, -1f
        )
    }

    private var surfaceHeight = 0f
    private var surfaceWidth = 0f

    private val bytesPerFloat = 4
    private val bytesPerPixel = 4 // GL_RGBA with GL_UNSIGNED_BYTE

    private val verticesData by lazy {
        ByteBuffer.allocateDirect(quadVertices.size * bytesPerFloat)
            .order(ByteOrder.nativeOrder()).asFloatBuffer().also {
                it.put(quadVertices)
            }
    }

    private val isProgramChanged = AtomicBoolean(false)
    private val shouldPlay = AtomicBoolean(false)

    private var programId: Int? = null
    private var positionAttributeLocation: Int? = null
    private var resolutionUniformLocation: Int? = null
    private var timeUniformLocation: Int? = null

    private lateinit var fragmentShader: String
    private lateinit var vertexShader: String
    private lateinit var eventSource: String

    // Shader clock passed to u_time: +0.01 per frame, wrapped after 30.
    private var frameCount = 0f

    private var getPaletteCallback: ((Palette) -> Unit)? = null

    private fun initializeSnapshotBuffer(width: Int, height: Int) =
        ByteBuffer.allocateDirect(width * height * bytesPerPixel).order(ByteOrder.nativeOrder())

    override fun onSurfaceCreated(gl: GL10?, config: EGLConfig?) {
        GLES20.glClearColor(0f, 0f, 0f, 1f)
        GLES20.glDisable(GLES20.GL_DITHER)
        // A new EGL context invalidates old GL objects, so rebuild the program if we have shaders.
        programId = null
        if (::fragmentShader.isInitialized) isProgramChanged.set(true)
    }

    // Called from the UI thread; the atomic flags hand the new sources over to the GL thread.
    fun setShaders(fragmentShader: String, vertexShader: String, source: String = "") {
        this.fragmentShader = fragmentShader
        this.vertexShader = vertexShader
        this.eventSource = source
        shouldPlay.compareAndSet(false, true)
        isProgramChanged.compareAndSet(false, true)
    }

    // Compiles one shader; returns 0 on failure.
    private fun createAndVerifyShader(source: String, type: Int): Int {
        val shaderId = GLES20.glCreateShader(type)
        if (shaderId == 0) {
            Timber.d("Could not create shader of type $type")
            return 0
        }
        GLES20.glShaderSource(shaderId, source)
        GLES20.glCompileShader(shaderId)

        val compileStatus = IntArray(1)
        GLES20.glGetShaderiv(shaderId, GLES20.GL_COMPILE_STATUS, compileStatus, 0)
        if (compileStatus[0] == 0) {
            // Read the log before deleting the shader.
            Timber.d("Compilation of shader failed. ${GLES20.glGetShaderInfoLog(shaderId)}")
            GLES20.glDeleteShader(shaderId)
            return 0
        }
        return shaderId
    }

    private fun setupProgram() {
        programId?.let { GLES20.glDeleteProgram(it) }
        programId = null

        val newProgramId = GLES20.glCreateProgram()
        if (newProgramId == 0) {
            Timber.d("Could not create new program")
            return
        }

        val vertShader = createAndVerifyShader(vertexShader, GLES20.GL_VERTEX_SHADER)
        val fragShader = createAndVerifyShader(fragmentShader, GLES20.GL_FRAGMENT_SHADER)
        if (vertShader == 0 || fragShader == 0) {
            // glDeleteShader silently ignores 0.
            GLES20.glDeleteShader(vertShader)
            GLES20.glDeleteShader(fragShader)
            GLES20.glDeleteProgram(newProgramId)
            return
        }

        GLES20.glAttachShader(newProgramId, vertShader)
        GLES20.glAttachShader(newProgramId, fragShader)
        GLES20.glLinkProgram(newProgramId)

        // The shader objects are no longer needed once linking has run.
        GLES20.glDetachShader(newProgramId, vertShader)
        GLES20.glDetachShader(newProgramId, fragShader)
        GLES20.glDeleteShader(vertShader)
        GLES20.glDeleteShader(fragShader)

        val linkStatus = IntArray(1)
        GLES20.glGetProgramiv(newProgramId, GLES20.GL_LINK_STATUS, linkStatus, 0)
        if (linkStatus[0] == 0) {
            // Read the log before deleting the program.
            Timber.d("Linking of program failed. ${GLES20.glGetProgramInfoLog(newProgramId)}")
            GLES20.glDeleteProgram(newProgramId)
            return
        }

        if (!validateProgram(newProgramId)) {
            Timber.d("Validating of program failed.")
            GLES20.glDeleteProgram(newProgramId)
            return
        }

        positionAttributeLocation = GLES20.glGetAttribLocation(newProgramId, "a_Position")
        resolutionUniformLocation = GLES20.glGetUniformLocation(newProgramId, "u_resolution")
        timeUniformLocation = GLES20.glGetUniformLocation(newProgramId, "u_time")

        // Point a_Position at the quad vertices (client-side array, 2 floats per vertex).
        verticesData.position(0)
        positionAttributeLocation?.let { attribLocation ->
            GLES20.glVertexAttribPointer(
                attribLocation,
                positionComponentCount,
                GLES20.GL_FLOAT,
                false,
                0,
                verticesData
            )
        }

        programId = newProgramId
    }

    override fun onSurfaceChanged(gl: GL10?, width: Int, height: Int) {
        GLES20.glViewport(0, 0, width, height)
        surfaceWidth = width.toFloat()
        surfaceHeight = height.toFloat()
        frameCount = 0f
    }

    override fun onDrawFrame(gl: GL10?) {
        if (!shouldPlay.get()) return
        Trace.beginSection(eventSource)
        try {
            GLES20.glClear(GLES20.GL_COLOR_BUFFER_BIT)

            if (isProgramChanged.getAndSet(false)) {
                setupProgram()
            }
            // The program must be current before setting uniforms or drawing.
            val program = programId ?: return
            GLES20.glUseProgram(program)

            val position = positionAttributeLocation ?: return
            GLES20.glEnableVertexAttribArray(position)

            resolutionUniformLocation?.let { GLES20.glUniform2f(it, surfaceWidth, surfaceHeight) }
            timeUniformLocation?.let { GLES20.glUniform1f(it, frameCount) }

            GLES20.glDrawArrays(GLES20.GL_TRIANGLE_STRIP, 0, 4)
            GLES20.glDisableVertexAttribArray(position)

            // Sample the frame once for the palette callback (invoked on the GL thread).
            getPaletteCallback?.let { callback ->
                getCurrentBitmap()?.let { bitmap ->
                    val palette = Palette.Builder(bitmap)
                        .maximumColorCount(6)
                        .addTarget(Target.VIBRANT)
                        .generate()
                    callback(palette)
                    getPaletteCallback = null
                    bitmap.recycle()
                }
            }

            if (frameCount > 30) {
                frameCount = 0f
            }
            frameCount += 0.01f
        } finally {
            // Runs on every path, including the early returns above.
            Trace.endSection()
        }
    }

    private fun getCurrentBitmap(): Bitmap? {
        val width = surfaceWidth.roundToInt()
        val height = surfaceHeight.roundToInt()

        // Read the central half of the surface: start a quarter in, span half the size.
        val quarterWidth = width / 4
        val quarterHeight = height / 4
        val halfWidth = width / 2
        val halfHeight = height / 2
        if (halfWidth == 0 || halfHeight == 0) return null

        val snapshot = initializeSnapshotBuffer(halfWidth, halfHeight)
        GLES20.glReadPixels(
            quarterWidth,
            quarterHeight,
            halfWidth,
            halfHeight,
            GLES20.GL_RGBA,
            GLES20.GL_UNSIGNED_BYTE,
            snapshot
        )
        snapshot.rewind()

        // Rows come back bottom-up; orientation does not matter for a palette.
        val region = Bitmap.createBitmap(halfWidth, halfHeight, Bitmap.Config.ARGB_8888)
        region.copyPixelsFromBuffer(snapshot)

        // Palette only needs a few pixels, so downscale before generating it.
        val scaled = Bitmap.createScaledBitmap(region, 24, 24, true)
        if (scaled !== region) region.recycle()
        return scaled
    }

    private fun validateProgram(programObjectId: Int): Boolean {
        GLES20.glValidateProgram(programObjectId)
        val validateStatus = IntArray(1)
        GLES20.glGetProgramiv(programObjectId, GLES20.GL_VALIDATE_STATUS, validateStatus, 0)

        Timber.tag("Results of validating").v(
            "${validateStatus[0]} \n  Log : ${
                GLES20.glGetProgramInfoLog(
                    programObjectId
                )
            } \n".trimIndent()
        )

        return validateStatus[0] != 0
    }

    fun setPaletteCallback(callback: (Palette) -> Unit) {
        getPaletteCallback = callback
    }

    fun onResume() {
        shouldPlay.compareAndSet(false, ::fragmentShader.isInitialized)
    }

    fun onPause() {
        shouldPlay.compareAndSet(true, false)
    }
}
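The renderer advances `u_time` by a fixed 0.01 per frame, so the animation speed depends on the frame rate, and it wraps back to 0 once the value passes 30. If you want frame-rate-independent motion, one option is to derive the uniform from elapsed wall-clock time instead. Below is a minimal sketch, assuming a monotonic nanosecond source such as `System.nanoTime()` (injected here so it can be tested). `ShaderClock` is a hypothetical class; the value you would pass to `glUniform1f` is `seconds()`.

```kotlin
// Hypothetical frame-rate-independent clock for the u_time uniform.
// Wrapping keeps the value small to limit precision loss in a mediump shader.
class ShaderClock(
    private val wrapSeconds: Float = 30f,
    private val nanoTime: () -> Long = System::nanoTime,
) {
    private val startNanos = nanoTime()

    // Seconds elapsed since construction, wrapped into [0, wrapSeconds).
    fun seconds(): Float {
        val elapsed = (nanoTime() - startNanos) / 1_000_000_000.0
        return (elapsed % wrapSeconds).toFloat()
    }
}
```

As with the original counter, wrapping still causes a visible jump unless the animation’s period divides the wrap interval evenly.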

ShaderGLSurfaceView

Great, everything our ShaderGLSurfaceView needs is ready. The view itself only has to request an OpenGL ES 2.0 context, preserve it across pauses where possible, and attach the renderer exactly once. The fragment shader does the rest!

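import android.content.Context
import android.opengl.GLSurfaceView
import android.util.AttributeSet
import timber.log.Timber

class ShaderGLSurfaceView @JvmOverloads constructor(
    context: Context,
    attrs: AttributeSet? = null,
) : GLSurfaceView(context, attrs) {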

    init {
        // Create an OpenGL ES 2.0 context
        setEGLContextClientVersion(2)
        preserveEGLContextOnPause = true
    }

    private var hasSetShader = false

    // GLSurfaceView throws if setRenderer() is called more than once, so guard it.
    fun setShaderRenderer(renderer: Renderer) {
        if (!hasSetShader) {
            setRenderer(renderer)
            hasSetShader = true
        }
    }

    override fun onResume() {
        super.onResume()
        Timber.d("ShaderGLSurfaceView onResume")
    }

    override fun onPause() {
        super.onPause()
        Timber.d("ShaderGLSurfaceView onPause")
    }

    override fun onDetachedFromWindow() {
        super.onDetachedFromWindow()
        Timber.d("ShaderGLSurfaceView onDetachedFromWindow")
    }
}

GLShader Composable

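Great, everything looks set for us. Let’s define our composable, which hosts the view with AndroidView and forwards resume and pause events to both the view and the renderer, and then we’re good to go.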

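import androidx.compose.runtime.Composable
import androidx.compose.runtime.DisposableEffect
import androidx.compose.runtime.remember
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.viewinterop.AndroidView
import androidx.lifecycle.Lifecycle
import androidx.lifecycle.LifecycleEventObserver
// On older Compose versions, LocalLifecycleOwner lives in androidx.compose.ui.platform.
import androidx.lifecycle.compose.LocalLifecycleOwner
import timber.log.Timber

@Composable
fun GLShader(
    renderer: ShaderRenderer,
    modifier: Modifier = Modifier
) {
    val context = LocalContext.current
    // Create the view once and keep it across recompositions, so the lifecycle
    // observer below always talks to the same instance.
    val view = remember {
        ShaderGLSurfaceView(context).apply { setShaderRenderer(renderer) }
    }
    val lifecycle = LocalLifecycleOwner.current.lifecycle

    DisposableEffect(lifecycle) {
        val observer = LifecycleEventObserver { _, event ->
            when (event) {
                Lifecycle.Event.ON_RESUME -> {
                    view.onResume()
                    renderer.onResume()
                }
                Lifecycle.Event.ON_PAUSE -> {
                    view.onPause()
                    renderer.onPause()
                }
                else -> Unit
            }
        }
        // addObserver replays events up to the current state, so ON_RESUME
        // fires right away if the screen is already resumed.
        lifecycle.addObserver(observer)

        onDispose {
            Timber.d("View disposed ${view.hashCode()}")
            lifecycle.removeObserver(observer)
            renderer.onPause()
            view.onPause()
        }
    }

    AndroidView(
        modifier = modifier,
        factory = { view }
    )
}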

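That’s all :) To wire it up, read both shader files from res/raw into strings (for example with resources.openRawResource(id).bufferedReader().use { it.readText() }), pass them to renderer.setShaders(fragmentShader, vertexShader), keep the ShaderRenderer in remember so it survives recomposition, and call GLShader(renderer). Now hit run and you’re good to go.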

I’d highly recommend this course if you’re looking to get started.

P.S. This is not my original work; the ideas are borrowed from elsewhere.

Subscribe to Sid Pillai
