This commit is contained in:
2026-03-17 23:38:50 +01:00
parent 074db5e8f6
commit cfd761c70c
13 changed files with 48606 additions and 0 deletions

777
leadfinder/src/main.rs Normal file
View File

@@ -0,0 +1,777 @@
//TODO include [Y/n] confirmation for API-calling commands and -y flag
//TODO Use tracing crate
use clap::{Parser, Subcommand, ValueEnum};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
env,
fmt::{Debug, Display},
fs,
path::{Path, PathBuf},
process::ExitCode,
str::FromStr,
};
use tokio::time;
// The subcommands of the CLI; `run` has one match arm per variant.
// The `///` comments below are rendered by clap as `--help` text, so they are
// user-facing strings and must describe what the program really does.
#[derive(Subcommand)]
enum Task {
    /// Parses a given input file (.csv) and lists the contained companies.
    /// Fails if the `Company Name` or `Website` column does not exist.
    List {
        /// The csv file.
        csv_file: String,
    },
    /// Parses a given input file (.json) and counts the number of occurrences per company as a reference client of other companies.
    /// Expects the json file structure given by the LLM prompts
    /// (TODO: Not quite. If there are multiple prompt results in a single file, concatenation of multiple results has to be done manually at the moment. That's why we like helix tho.)
    CountPartners {
        /// The json file, containing the required format.
        json_file: String,
    },
    /// WARNING: This might cost money to run!
    /// Tests a given system prompt on a given model with the context prompt "Research Target Company T: Pandaloop\nWebsite: https://pandaloop.de\n"
    TestWithPandaloop {
        /// The model to send the prompt to.
        #[arg(short, long)]
        model: Model,
        /// File containing the system prompt.
        #[arg(short, long)]
        system_prompt_file: String,
        /// Extract the model's JSON answer from the API response and pretty-print it.
        #[arg(short, long)]
        pretty: bool,
    },
    /// WARNING: This might cost money to run!
    /// Sends one prompt for the single company at the given index of the csv file.
    SinglePrompt {
        /// Zero-based index of the company, as printed by `list`.
        idx: usize,
        /// The csv file containing the companies.
        #[arg(short, long)]
        csv_file: String,
        /// The model to send the prompt to.
        #[arg(short, long)]
        model: Model,
        /// File containing the system prompt.
        #[arg(short, long)]
        system_prompt_file: String,
        /// Extract the model's JSON answer from the API response and pretty-print it.
        #[arg(short, long)]
        pretty: bool,
    },
    /// WARNING: This might cost money to run!
    /// Sends a single prompt covering all companies with an index in range_start..=range_end.
    RangePrompt {
        /// Zero-based index of the first company (inclusive).
        range_start: usize,
        /// Zero-based index of the last company (inclusive). Must be larger than range_start.
        range_end: usize,
        /// The csv file containing the companies.
        #[arg(short, long)]
        csv_file: String,
        /// The model to send the prompt to.
        #[arg(short, long)]
        model: Model,
        /// File containing the system prompt.
        #[arg(short, long)]
        system_prompt_file: String,
        /// Extract the model's JSON answer from the API response and pretty-print it.
        #[arg(short, long)]
        pretty: bool,
    },
    /// WARNING: This might cost A LOT of money to run!
    /// Walks the companies from range_start up to range_end in batches of range_step and sends one prompt per batch.
    Multiprompt {
        /// Zero-based index of the first company.
        range_start: usize,
        /// Upper bound of the index range.
        range_end: usize,
        /// Number of companies per prompt.
        range_step: usize,
        /// The csv file containing the companies.
        #[arg(short, long)]
        csv_file: String,
        /// The model to send the prompt to.
        #[arg(short, long)]
        model: Model,
        /// File containing the system prompt.
        #[arg(short, long)]
        system_prompt_file: String,
        /// Extract the model's JSON answer from the API response and pretty-print it.
        #[arg(short, long)]
        pretty: bool,
    },
}
#[derive(thiserror::Error, Debug)]
enum Error {
#[error("context index range is invalid: {0}")]
InvalidRange(String),
#[error("api key {0} does not exist in .env file")]
ApiKeyNotFound(Model),
#[error("Error parsing {filename:?} as csv file: {source}")]
CsvFile {
#[source]
source: csv::Error,
filename: Option<PathBuf>,
},
#[error("Error parsing json: {0}")]
SerdeJson(#[from] serde_json::Error),
#[error("Error parsing {filename:?} as json: {source}")]
SerdeJsonFile {
#[source]
source: serde_json::Error,
filename: Option<PathBuf>,
},
#[error("Error parsing json: \n\njson:\n{info}")]
SerdeJsonInfo {
#[source]
source: serde_json::Error,
info: String
},
#[error("reqwest Error: {0}")]
Reqwest(#[from] reqwest::Error),
#[error("io Error: {0}")]
Io(#[from] std::io::Error),
}
/// One record of the input csv file.
///
/// Both columns are mandatory: deserialization of a record fails when the
/// `Company Name` or `Website` column is absent.
#[derive(Deserialize, Debug)]
struct CsvCompany {
    /// Value of the `Company Name` column.
    #[serde(rename = "Company Name")]
    name: String,
    /// Value of the `Website` column.
    #[serde(rename = "Website")]
    website: String,
}
/// One researched company from the JSON result file, together with the
/// connections that were found for it.
// `allow(unused)`: not every field is read, but all of them must be present in the JSON.
#[allow(unused)]
#[derive(Deserialize, Debug)]
struct SerdeJsonTargetCompany {
    /// Stored under the JSON key `target_company`.
    #[serde(rename = "target_company")]
    name: String,
    /// The partner/client entries listed for this company.
    connections: Vec<SerdeJsonPartnerCompany>,
}
/// One entry of a target company's `connections` array.
// `allow(unused)`: only `category` and `partner_name` are read by the visible code.
#[allow(unused)]
#[derive(Deserialize, Debug)]
struct SerdeJsonPartnerCompany {
    /// Kind of connection; `run` counts the entries whose category is `REFERENCE_CLIENT`.
    category: String,
    context: String,
    /// Name under which the partner is counted.
    partner_name: String,
    people: Vec<String>,
    source_type: String,
}
/// A single contact entry of a lead (see `SerdeJsonLead::general_contacts`).
// `allow(unused)`: the fields are never read in the visible code; the struct
// only describes the expected JSON shape.
#[allow(unused)]
#[derive(Deserialize, Debug)]
struct SerdeJsonContact {
    value: String,
    value_type: String,
    category: String,
    label: String,
    source_url: String,
}
/// A single person working at a lead company.
// `allow(unused)`: the fields are never read in the visible code; the struct
// only describes the expected JSON shape.
#[allow(unused)]
#[derive(Deserialize, Debug)]
struct SerdeJsonEmployee {
    name: String,
    role: String,
    email: String,
    phone: String,
    linkedin_url: String,
    source_url: String,
}
#[allow(unused)]
#[derive(Debug, Deserialize)]
struct SerdeJsonLead {
company_name: String,
website: String,
industry: String,
description: String,
employee_count: String,
general_contacts: Vec<SerdeJsonContact>,
employees: Vec<SerdeJsonContact>,
}
fn prettify(content: &str, model: Model) -> Result<String, Error> {
let ser_all: serde_json::Value = serde_json::from_str(content)?;
let cont = match model {
Model::Gemini => ser_all["candidates"][0]["content"]["parts"][0]["text"]
.as_str()
.unwrap()
.replace("\\\"", "\"")
.replace("\\n", "\n"),
Model::Perplexity => {
let raw_msg = ser_all["choices"][0]["message"]["content"]
.as_str()
.unwrap()
.replace("\\\"", "\"")
.replace("\\n", "\n");
let pref_stripped_msg = raw_msg.strip_prefix("```json").unwrap_or(&raw_msg);
let bidir_stripped_msg = pref_stripped_msg
.strip_suffix("```")
.unwrap_or(&pref_stripped_msg);
bidir_stripped_msg.to_owned()
}
};
eprintln!("{cont}");
let ser_cont: serde_json::Value =
serde_json::from_str(&cont).map_err(|source| Error::SerdeJson(source))?;
let cont_pretty = serde_json::to_string_pretty(&ser_cont)?;
Ok(cont_pretty)
}
// The LLM backends a prompt can be sent to. `ValueEnum` lets clap parse the
// value of `--model`; `Gemini` is the `Default` variant.
// (Plain `//` comments on purpose: clap may render `///` comments as help text.)
#[derive(Clone, Copy, Debug, Default, Serialize, ValueEnum)]
enum Model {
    #[default]
    Gemini,
    Perplexity,
}
/// Human-readable vendor and model name; this text also ends up in the
/// `Error::ApiKeyNotFound` message.
impl Display for Model {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Model::Gemini => "Google Gemini: 2.5 Flash",
            Model::Perplexity => "Perplexity AI: Sonar",
        };
        f.write_str(label)
    }
}
/// The API keys handed to `prompt`, one per supported `Model`.
struct ApiKeys {
    /// Sent as the `x-goog-api-key` header by `gemini_prompt`.
    gemini_api_key: String,
    /// Sent as the bearer token by `perplexity_prompt`.
    perplexity_api_key: String,
}
// Top-level command-line arguments: the whole interface is a single subcommand.
// (Plain `//` comment on purpose: clap would render a `///` comment as help text.)
#[derive(Parser)]
struct Args {
    // The subcommand to execute.
    #[command(subcommand)]
    task: Task,
}
fn parse_csv(file_path: &Path) -> Result<Vec<CsvCompany>, Error> {
let mut reader = csv::ReaderBuilder::new()
.has_headers(true)
.from_path(&file_path)
.map_err(|source| Error::CsvFile {
source,
filename: Some(file_path.to_owned()),
})?;
let mut companies = Vec::new();
for res in reader.deserialize() {
let record: CsvCompany = res.map_err(|source| Error::CsvFile {
source,
filename: Some(file_path.to_owned()),
})?;
companies.push(record);
}
Ok(companies)
}
/// Sends `system_prompt` and `context_prompt` to the backend selected by `model`.
///
/// Picks the matching key out of `api_keys` and forwards everything to
/// `gemini_prompt` or `perplexity_prompt`. The response is returned as-is;
/// the HTTP status is not inspected here.
///
/// # Errors
///
/// Passes on the `reqwest::Error` of the backend-specific function.
async fn prompt(
    model: Model,
    reqclient: &reqwest::Client,
    api_keys: &ApiKeys,
    system_prompt: &str,
    context_prompt: &str,
) -> Result<reqwest::Response, reqwest::Error> {
    let response = match model {
        Model::Gemini => {
            let key = &api_keys.gemini_api_key;
            gemini_prompt(reqclient, key, system_prompt, context_prompt).await?
        }
        Model::Perplexity => {
            let key = &api_keys.perplexity_api_key;
            perplexity_prompt(reqclient, key, system_prompt, context_prompt).await?
        }
    };
    Ok(response)
}
async fn perplexity_prompt(
reqclient: &reqwest::Client,
perplexity_api_key: &str,
system_prompt: &str,
context_prompt: &str,
) -> Result<reqwest::Response, reqwest::Error> {
let body = format!(
r#"{{
"model": "sonar",
"messages": [
{{
"role": "system",
"content": "{system_prompt}"
}},
{{
"role": "user",
"content": "{context_prompt}"
}}
],
"temperature": 0
}}"#,
system_prompt = system_prompt.replace('"', "\\\"").replace('\n', "\\n"),
context_prompt = context_prompt.replace('"', "\\\"").replace('\n', "\\n")
);
reqclient
.post("https://api.perplexity.ai/chat/completions")
.header("accept", "application/json")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {perplexity_api_key}"))
.body(body)
.send()
.await
}
async fn gemini_prompt(
reqclient: &reqwest::Client,
gemini_api_key: &str,
system_prompt: &str,
context_prompt: &str,
) -> Result<reqwest::Response, reqwest::Error> {
let body = format!(
r#"{{
"contents": [
{{
"parts": [
{{
"text": "SYSTEM PROMPT:\n\n\n{system_prompt}\n\nContext Prompt:{context_prompt}\n\n\n\nSYSTEM PROMPT:\n\n\n{system_prompt}\n\nContext Prompt:{context_prompt}\n"
}}
]
}}
],
"tools": [
{{
"google_search": {{}}
}}
],
"generationConfig": {{
"temperature": 0.0,
}}
}}"#,
system_prompt = system_prompt.replace('"', "\\\"").replace('\n', "\\n"),
context_prompt = context_prompt.replace('"', "\\\"").replace('\n', "\\n")
);
reqclient
.post("https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent")
.header("x-goog-api-key", format!("{gemini_api_key}"))
.header("Content-Type", "application/json")
.body(body)
.send()
.await
}
async fn run(args: Args) -> Result<(), Error> {
let _ = dotenv::dotenv();
let mut vars = env::vars().collect::<HashMap<String, String>>();
let client = reqwest::Client::new();
let gemini_key = match vars.remove("GEMINI_API_KEY") {
Some(key) => key,
None => {
return Err(Error::ApiKeyNotFound(Model::Gemini));
}
};
let perplexity_key = match vars.remove("PERPLEXITY_API_KEY") {
Some(key) => key,
None => {
return Err(Error::ApiKeyNotFound(Model::Perplexity));
}
};
let api_keys = ApiKeys {
gemini_api_key: gemini_key,
perplexity_api_key: perplexity_key,
};
match args.task {
Task::List { csv_file } => {
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
let csv = parse_csv(&pathbuf)?;
for (idx, company) in csv.iter().enumerate() {
println!(
"n = {: >5} | Company = {:?} ({:?})",
idx, company.name, company.website
);
}
}
Task::CountPartners { json_file } => {
let pathbuf = PathBuf::from_str(&json_file).unwrap_or_else(|_| unreachable!());
let datafile = fs::read_to_string(json_file).unwrap();
let companies: Vec<SerdeJsonTargetCompany> =
serde_json::from_str(&datafile).map_err(|source| Error::SerdeJsonFile {
source,
filename: Some(pathbuf),
})?;
let mut partner_occurrences = HashMap::new();
for target_company in companies {
let connections_array = target_company.connections;
for connection in connections_array {
if connection.category == "REFERENCE_CLIENT" {
partner_occurrences
.entry(connection.partner_name)
.and_modify(|count| *count += 1)
.or_insert(1);
}
}
}
let mut count_vec: Vec<_> = partner_occurrences.iter().collect();
count_vec.sort_by(|a, b| b.1.cmp(a.1));
for (k, v) in count_vec {
println!("Company {:>60} appears {:>3} times", k, v);
}
}
Task::TestWithPandaloop {
model,
pretty,
system_prompt_file,
} => {
let system_prompt =
fs::read_to_string(system_prompt_file).unwrap_or_else(|_| unreachable!());
let context_prompt =
"Research Target Company T: Pandaloop\nWebsite: https://pandaloop.de\n".to_owned();
let mut response = prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
.await?
.text()
.await?;
if pretty {
response = prettify(&response, model)?;
}
println!("{response}");
}
Task::SinglePrompt {
idx,
csv_file,
model,
system_prompt_file,
pretty,
} => {
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
let csv = parse_csv(&pathbuf)?;
let system_prompt = fs::read_to_string(system_prompt_file)?;
if let Some(company) = csv.get(idx) {
let context_prompt = format!(
"Research Target Company: {}\nWebsite: {}\n",
company.name, company.website
);
let mut response =
prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
.await?
.text()
.await?;
if pretty {
response = prettify(&response, model)?;
}
println!("{response}");
} else {
eprintln!("Index {:?} out of range, skipping.", idx);
}
}
Task::RangePrompt {
range_start,
range_end,
csv_file,
model,
system_prompt_file,
pretty,
} => {
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
let csv = parse_csv(&pathbuf)?;
let system_prompt = fs::read_to_string(system_prompt_file)?;
if range_start >= range_end {
return Err(Error::InvalidRange(format!(
"range args {range_start}..{range_end} are invalid: range_start should be smaller than range_end"
)));
}
let mut context_prompt = String::new();
for i in range_start..=range_end {
if let Some(company) = csv.get(i) {
context_prompt = format!(
"{}\nResearch Target Company T: {}\nWebsite: {}\nTask: Find all companies that are partners or clients of T based on the research guidelines.",
context_prompt, company.name, company.website
);
} else {
eprintln!("Index {:?} out of range, skipping.", i);
}
}
let mut response = prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
.await?
.text()
.await?;
if pretty {
response = prettify(&response, model)?;
}
println!("{response}");
}
Task::Multiprompt {
range_start,
range_end,
range_step,
csv_file,
model,
system_prompt_file,
pretty,
} => {
let pathbuf = PathBuf::from_str(&csv_file).unwrap_or_else(|_| unreachable!());
let csv = parse_csv(&pathbuf)?;
let system_prompt = fs::read_to_string(system_prompt_file)?;
if range_start >= range_end {}
let mut i = range_start;
while i <= range_end {
let mut context_prompt = String::new();
for j in i..i + range_step {
if let Some(company) = csv.get(j) {
context_prompt = format!(
"{}\nResearch Target Company T: {}\nWebsite: {}\n\n\n",
context_prompt, company.name, company.website
);
} else {
eprintln!("Index {:?} out of range, skipping.", j);
}
}
eprintln!(
"Dispatching prompt for range {}..={}",
i,
i + range_step - 1
);
let response_result =
prompt(model, &client, &api_keys, &system_prompt, &context_prompt).await;
match response_result {
Ok(response) => {
if response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS {
eprintln!("429 reached. Go to bed.");
return Ok(());
} else {
}
let mut response_text = response.text().await?;
if pretty {
response_text = prettify(&response_text, model)?;
}
println!("{response_text}");
i += range_step;
}
Err(error) => {
eprintln!("{}", error);
let _ = time::sleep(time::Duration::from_millis(5000));
}
}
}
}
}
Ok(())
// match args {
// Args {
// list: true,
// test: _,
// index: _,
// range: _,
// multiprompt_range: _,
// output_file: _,
// model: _,
// pretty: _,
// count_partners: _,
// } => {
// for (idx, company) in csv.iter().enumerate() {
// println!("n = {: >5} | Company = {:?}", idx, company.name);
// }
// }
// Args {
// list: _,
// test: _,
// index: _,
// range: _,
// multiprompt_range: _,
// output_file: _,
// model: _,
// pretty: _,
// count_partners: true,
// } => {
// // TODO not hard-coded
// let mut partner_occurrences = HashMap::new();
// let datafile =
// fs::read_to_string("/home/jan/pl/qwertus/leadfinder/out/perp_out.json").unwrap();
// let json: serde_json::Value = serde_json::from_str(&datafile)?;
// let company_array = json.as_array().unwrap_or_else(|| panic!("U fucked up! json file to inspect should have array of companies as first level value."));
// for target_company in company_array {
// let connections_array = target_company["connections"].as_array().unwrap_or_else(|| panic!("U fucked up! json file to inspect should have array of companies as first level value."));
// for connection in connections_array {
// if connection["category"].as_str().unwrap_or_else(|| panic!("Ur AI probably fucked up: connection company with no partner_name in file!")) == "REFERENCE_CLIENT" {
// partner_occurrences
// .entry(connection["partner_name"].as_str().unwrap_or_else(|| panic!("Ur AI probably fucked up: connection company with no partner_name in file!")))
// .and_modify(|count| *count += 1)
// .or_insert(1);
// }
// }
// }
// let mut count_vec: Vec<_> = partner_occurrences.iter().collect();
// count_vec.sort_by(|a, b| b.1.cmp(a.1));
// for (k, v) in count_vec {
// println!("Company {:>60} appears {:>3} times", k, v);
// }
// }
// Args {
// test: true,
// list: _,
// index: _,
// range: _,
// multiprompt_range: _,
// output_file: _,
// model,
// pretty,
// count_partners: _,
// } => {
// let context_prompt = "Research Target Company T: Pandaloop\n\
// Website: https://pandaloop.de\n\
// "
// .to_owned();
// let mut response = prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
// .await?
// .text()
// .await?;
// if pretty {
// response = prettify(&response, model)?;
// }
// file.write_all(&response.as_bytes())?;
// }
// Args {
// index: Some(i),
// list: _,
// test: _,
// range: _,
// multiprompt_range: _,
// output_file: _,
// model,
// pretty,
// count_partners: _,
// } => {
// if let Some(company) = csv.get(i) {
// let context_prompt = format!(
// "Research Target Company T: {}\nWebsite: {}\nTask: Find all companies that are partners or clients of T based on the research guidelines.",
// company.name, company.website
// );
// let mut response =
// prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
// .await?
// .text()
// .await?;
// if pretty {
// response = prettify(&response, model)?;
// }
// file.write_all(&response.as_bytes())?;
// } else {
// println!("Index {:?} out of range, skipping.", i);
// }
// }
// Args {
// list: _,
// test: _,
// index: _,
// range: Some(range),
// multiprompt_range: _,
// output_file: _,
// model,
// pretty,
// count_partners: _,
// } => {
// if range[0] >= range[1] {
// println!("Range start larger than range end. Doing nothing.");
// return Ok(());
// }
// let mut context_prompt = String::new();
// for i in range[0]..=range[1] {
// if let Some(company) = csv.get(i) {
// context_prompt = format!(
// "{}\nResearch Target Company T: {}\nWebsite: {}\nTask: Find all companies that are partners or clients of T based on the research guidelines.",
// context_prompt, company.name, company.website
// );
// } else {
// println!("Index {:?} out of range, skipping.", i);
// }
// }
// let mut response = prompt(model, &client, &api_keys, &system_prompt, &context_prompt)
// .await?
// .text()
// .await?;
// if pretty {
// response = prettify(&response, model)?;
// }
// file.write_all(&response.as_bytes())?;
// }
// Args {
// list: _,
// test: _,
// index: _,
// range: _,
// multiprompt_range: Some(multirange),
// output_file: _,
// model,
// pretty,
// count_partners: _,
// } => {
// if multirange[0] >= multirange[1] {
// println!("Range start larger than range end. Doing nothing.");
// return Ok(());
// }
// let mut i = multirange[0];
// while i <= multirange[1] {
// let mut context_prompt = String::new();
// for j in i..i + multirange[2] {
// if let Some(company) = csv.get(j) {
// context_prompt = format!(
// "{}\nResearch Target Company T: {}\nWebsite: {}\nTask: Find all companies that are partners or clients of T based on the research guidelines.\n\n",
// context_prompt, company.name, company.website
// );
// } else {
// println!("Index {:?} out of range, skipping.", j);
// }
// }
// println!(
// "Dispatching prompt for range {}..={}",
// i,
// i + multirange[2] - 1
// );
// let response_result =
// prompt(model, &client, &api_keys, &system_prompt, &context_prompt).await;
// match response_result {
// Ok(response) => {
// if response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS {
// println!("429 reached. Go to bed.");
// return Ok(());
// } else {
// }
// let mut response_text = response.text().await?;
// if pretty {
// response_text = prettify(&response_text, model)?;
// }
// file.write_all(&response_text.as_bytes())?;
// i += multirange[2];
// }
// Err(error) => {
// println!("{}", error);
// let _ = time::sleep(time::Duration::from_millis(5000));
// }
// }
// }
// }
// _ => unreachable!(),
// }
// - denkwerk Partner: *fraenk*, *OTTO*, *SWR*, *Union Investment*, *ARAG*, erenja, BurdaForward, Stiebel Eltron, edding, polymore, Telekom, Santander, DeepL, motel one, Struggly, Fondation Beyeler, Charite, mainova, Esprit, Sport Cast, Remondis, condor, , Aktion Mensch, easy credit, Sparkasse KölnBonn, multiloop, Storck, Teambank, Ranger, Microsoft
}
/// Entry point: parses the command line, executes the selected task and maps
/// the outcome to the process exit code. Errors are printed to stderr.
#[tokio::main]
async fn main() -> ExitCode {
    if let Err(err) = run(Args::parse()).await {
        // TODO: give friendlier messages for common cases, e.g. a csv file
        // that does not exist (`csv::ErrorKind::Io` with `ErrorKind::NotFound`).
        eprintln!("{err}");
        return ExitCode::FAILURE;
    }
    ExitCode::SUCCESS
}