Commit 78b70e2

ashleyxuu authored and Genesis929 committed
fix: update the llm+kmeans notebook with recent change (#236)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://togithub.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes internal issue 313682530 🦕
1 parent f20d00b commit 78b70e2

1 file changed: +12 -35 lines changed

notebooks/generative_ai/bq_dataframes_llm_kmeans.ipynb

@@ -366,18 +366,6 @@
     "predicted_embeddings.head() "
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "id": "4H_etYfsEOFP"
-   },
-   "outputs": [],
-   "source": [
-    "# Join the complaints with their embeddings in the same DataFrame\n",
-    "combined_df = downsampled_issues_df.join(predicted_embeddings)"
-   ]
-  },
   {
    "attachments": {},
    "cell_type": "markdown",
@@ -426,30 +414,19 @@
    "outputs": [],
    "source": [
     "# Use KMeans clustering to calculate our groups. Will take ~3 minutes.\n",
-    "cluster_model.fit(combined_df[[\"text_embedding\"]])\n",
-    "clustered_result = cluster_model.predict(combined_df[[\"text_embedding\"]])\n",
+    "cluster_model.fit(predicted_embeddings[[\"text_embedding\"]])\n",
+    "clustered_result = cluster_model.predict(predicted_embeddings)\n",
     "# Notice the CENTROID_ID column, which is the ID number of the group that\n",
     "# each complaint belongs to.\n",
     "clustered_result.head(n=5)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Join the group number to the complaints and their text embeddings\n",
-    "combined_clustered_result = combined_df.join(clustered_result)\n",
-    "combined_clustered_result.head(n=5) "
-   ]
-  },
   {
    "attachments": {},
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "Our dataframe combined_clustered_result now has three columns: the complaints, their text embeddings, and an ID from 1-10 (inclusive) indicating which semantically similar group they belong to."
+    "Our dataframe combined_clustered_result now has three complaint columns: the content, their text embeddings, and an ID from 1-10 (inclusive) indicating which semantically similar group they belong to."
    ]
   },
   {
@@ -480,14 +457,14 @@
    "source": [
     "# Using bigframes, with syntax identical to pandas,\n",
     "# filter out the first and second groups\n",
-    "cluster_1_result = combined_clustered_result[\n",
-    "    combined_clustered_result[\"CENTROID_ID\"] == 1\n",
-    "][[\"consumer_complaint_narrative\"]]\n",
+    "cluster_1_result = clustered_result[\n",
+    "    clustered_result[\"CENTROID_ID\"] == 1\n",
+    "][[\"content\"]]\n",
     "cluster_1_result_pandas = cluster_1_result.head(5).to_pandas()\n",
     "\n",
-    "cluster_2_result = combined_clustered_result[\n",
-    "    combined_clustered_result[\"CENTROID_ID\"] == 2\n",
-    "][[\"consumer_complaint_narrative\"]]\n",
+    "cluster_2_result = clustered_result[\n",
+    "    clustered_result[\"CENTROID_ID\"] == 2\n",
+    "][[\"content\"]]\n",
     "cluster_2_result_pandas = cluster_2_result.head(5).to_pandas()"
    ]
   },
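
Note on the hunk above: the cell now filters `clustered_result` directly on `CENTROID_ID` and keeps only the `content` column. A minimal standard-library sketch of that select-then-take-five step follows. It assumes each row can be pictured as a dict with `content` and `CENTROID_ID` keys; the sample rows and the helper name `contents_for_cluster` are hypothetical and are not part of the notebook or of bigframes.

```python
from typing import Dict, List, Union

Row = Dict[str, Union[int, str]]

# Hypothetical stand-in rows for clustered_result; not real notebook output.
clustered_rows: List[Row] = [
    {"content": "Card was charged twice.", "CENTROID_ID": 1},
    {"content": "Loan payment was misapplied.", "CENTROID_ID": 2},
    {"content": "Duplicate fee on my card.", "CENTROID_ID": 1},
]


def contents_for_cluster(rows: List[Row], centroid_id: int, limit: int = 5) -> List[str]:
    """Return up to `limit` content strings for one cluster, in row order."""
    matches = [str(row["content"]) for row in rows if row["CENTROID_ID"] == centroid_id]
    return matches[:limit]


cluster_1_contents = contents_for_cluster(clustered_rows, 1)
cluster_2_contents = contents_for_cluster(clustered_rows, 2)
```

The slice `matches[:limit]` plays the role of `.head(5)`: it returns fewer rows when a cluster is small instead of failing.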
@@ -503,15 +480,15 @@
     "prompt1 = 'comment list 1:\\n'\n",
     "for i in range(5):\n",
     "    prompt1 += str(i + 1) + '. ' + \\\n",
-    "        cluster_1_result_pandas[\"consumer_complaint_narrative\"].iloc[i] + '\\n'\n",
+    "        cluster_1_result_pandas[\"content\"].iloc[i] + '\\n'\n",
     "\n",
     "prompt2 = 'comment list 2:\\n'\n",
     "for i in range(5):\n",
     "    prompt2 += str(i + 1) + '. ' + \\\n",
-    "        cluster_2_result_pandas[\"consumer_complaint_narrative\"].iloc[i] + '\\n'\n",
+    "        cluster_2_result_pandas[\"content\"].iloc[i] + '\\n'\n",
     "\n",
     "print(prompt1)\n",
-    "print(prompt2)\n"
+    "print(prompt2)"
    ]
   },
   {
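
Note on the prompt-building cell in the last hunk: the loop uses `range(5)` with `.iloc[i]`, which would raise an `IndexError` if a cluster yielded fewer than five rows. The sketch below shows one way to build the same numbered-list text from however many comments are available. It is an illustration only: `build_prompt` and the sample comments are hypothetical and do not appear in the notebook.

```python
from typing import Iterable


def build_prompt(label: str, comments: Iterable[str]) -> str:
    """Format comments as '<label>:' followed by one numbered line per comment."""
    lines = [f"{label}:"]
    for number, comment in enumerate(comments, start=1):
        lines.append(f"{number}. {comment}")
    # Every line, including the last, ends with a newline, as in the notebook loop.
    return "\n".join(lines) + "\n"


# Hypothetical sample comments standing in for cluster_1_result_pandas["content"].
sample_comments = ["Card was charged twice.", "Duplicate fee on my card."]
prompt1 = build_prompt("comment list 1", sample_comments)
print(prompt1)
```

Iterating with `enumerate` keeps the numbering and text format of the original loop while tolerating short or empty clusters.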
